1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_ |
17 | #define TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_ |
18 | |
19 | #include <string> |
20 | |
21 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
22 | #include "tensorflow/core/framework/types.pb.h" |
23 | #include "tensorflow/core/lib/gtl/array_slice.h" |
24 | #include "tensorflow/core/lib/gtl/inlined_vector.h" |
25 | #include "tensorflow/core/lib/strings/str_util.h" |
26 | #include "tensorflow/core/platform/errors.h" |
27 | #include "tensorflow/core/platform/logging.h" |
28 | #include "tensorflow/core/platform/macros.h" |
29 | #include "tensorflow/core/platform/statusor.h" |
30 | |
31 | namespace tensorflow { |
32 | |
33 | // START_SKIP_DOXYGEN |
34 | template <class Shape> |
35 | class TensorShapeIter; |
36 | class TensorShape; |
37 | class TensorShapeProto; |
38 | class PartialTensorShape; |
39 | // END_SKIP_DOXYGEN |
40 | |
/// Internal representation for both TensorShape and PartialTensorShape.
class TensorShapeRep {
 public:
  ~TensorShapeRep();

  /// Copy the specified shape.
  TensorShapeRep(const TensorShapeRep& b);
  void operator=(const TensorShapeRep& b);

  /// Move the specified shape. After moving, `b` is safe for destruction and
  // can be reassigned into, but its dimensions and number of elements can be
  // nonsensical (e.g., negative dimension sizes, or number of elements not
  // properly recomputed).
  TensorShapeRep(TensorShapeRep&& b);
  void operator=(TensorShapeRep&& b);

  /// Clear a tensor shape, producing the scalar shape.
  void Clear();

  // Maximum number of dimensions in a tensor.
  // It's 254 because 255 = kUnknownRank is used to represent unknown rank.
  static constexpr int MaxDimensions() { return 254; }

  /// \brief Returns the number of elements in the tensor.
  ///
  /// We use `int64` and not `size_t` to be compatible with `Eigen::Tensor`
  /// which uses `ptrdiff_t`. For PartialTensorShape, -1 means not fully
  /// defined.
  int64_t num_elements() const { return num_elements_; }

  /// For error messages.
  std::string DebugString() const;
  static std::string DebugString(const TensorShapeProto& proto);

 protected:
  // Constructable only via TensorShapeBase.
  TensorShapeRep() = default;

  // Resets the shape while preserving the data_type byte stashed in buf()[13]
  // on behalf of the Tensor class (see the `friend class Tensor` note below).
  void ClearAllButDataType();

  // We use 16 bytes to represent a TensorShape. Because we need to
  // be able to support full 64-bit dimension sizes and an arbitrary
  // number of dimensions for a Tensor, but most tensor dimensions are
  // significantly smaller than 64 bits and most tensors are 1, 2, or 3
  // dimensions, we have several representations.
  // Rep16: Supports up to 6 dimensions where each dimension is < 2^16 - 1
  // Rep32: Supports up to 3 dimensions where each dimension is < 2^32 - 1
  // Rep64: Supports arbitrary dimensionality, 64-bit dimensions using
  //        an out of line vector.
  // For PartialTensorShape, a dimension of static_cast<uint??>(-1) is unknown.
  // This value is not allowed in TensorShape either for format compatibility.
  struct Rep16 {
    uint16 dims_[6];
  };
  struct Rep32 {
    uint32 dims_[3];
  };
  struct Rep64 {
    // Heap-allocated; owned by this TensorShapeRep while tag() is
    // REP_OUT_OF_LINE and freed by DestructorOutOfLine().
    gtl::InlinedVector<int64_t, 4>* dims_;
  };

  // We use the max value of uint16 or uint32 to represent unknown shapes, so
  // the maximum representable valid shape in these representations is one less.
  static constexpr int64_t kMaxRep16 = std::numeric_limits<uint16>::max() - 1;
  static constexpr int64_t kMaxRep32 = std::numeric_limits<uint32>::max() - 1;
  static constexpr uint16 kUnknownRep16 = std::numeric_limits<uint16>::max();
  static constexpr uint32 kUnknownRep32 = std::numeric_limits<uint32>::max();

  // Reinterpret the in-line buffer as one of the representations above.
  // Only valid to use the accessor matching the current tag().
  Rep16* as16() { return reinterpret_cast<Rep16*>(buf()); }
  Rep32* as32() { return reinterpret_cast<Rep32*>(buf()); }
  Rep64* as64() { return reinterpret_cast<Rep64*>(buf()); }

  const Rep16* as16() const { return reinterpret_cast<const Rep16*>(buf()); }
  const Rep32* as32() const { return reinterpret_cast<const Rep32*>(buf()); }
  const Rep64* as64() const { return reinterpret_cast<const Rep64*>(buf()); }

  // Discriminator for which representation currently occupies buf().
  enum RepTag { REP16 = 0, REP32 = 1, REP_OUT_OF_LINE = 2 };

  // Since we have a convenient extra byte available, we allow the
  // Tensor class to store an 8-bit value in this extra storage. This
  // allows it to store the Tensor's datatype enum value here and avoid
  // an extra word of storage.
  friend class Tensor;
  friend class TensorShapeTestHelper;
  DataType data_type() const { return static_cast<DataType>(buf()[13]); }
  void set_data_type(DataType dt) {
    // We only have 8 bits available to store DataType, so make sure it fits
    DCHECK_LT(static_cast<uint32>(dt), 256u);
    buf()[13] = static_cast<uint8>(dt);
  }

  // We store the number of dimensions in byte 14, and the RepTag in byte 15.
  // Bytes [0..13] vary depending on the representation.
  // A value of 255 indicates unknown rank in the PartialTensorShape case.
  static constexpr uint8 kUnknownRank = 255;
  uint8 ndims_byte() const { return buf()[14]; }
  void set_ndims_byte(uint8 nd) { buf()[14] = nd; }

  RepTag tag() const { return static_cast<RepTag>(buf()[15]); }
  void set_tag(RepTag tag) { buf()[15] = static_cast<uint8>(tag); }

  void set_num_elements(int64_t n) { num_elements_ = n; }

 private:
  // Frees the out-of-line Rep64 vector; only called when tag() is
  // REP_OUT_OF_LINE.
  void DestructorOutOfLine();
  void SlowCopyFrom(const TensorShapeRep& b);

  uint8* buf() { return &u_.buf[0]; }
  const uint8* buf() const { return &u_.buf[0]; }

  union {
    uint8 buf[16];
    // Force data to be aligned enough for a pointer.
    Rep64* unused_aligner;
  } u_;
  // Cached product of all dimension sizes (-1 for not-fully-defined
  // PartialTensorShape); kept in sync by the mutating operations.
  int64_t num_elements_;
};
158 | |
159 | /// Base class for TensorShape and PartialTensorShape. |
160 | /// The class is templatized by either TensorShape or PartialTensorShape to |
161 | /// allow skipping known/unknown checks in the TensorShape case, but the |
162 | /// representation is shared exactly for fast conversion. |
163 | template <class Shape> |
164 | class TensorShapeBase : public TensorShapeRep { |
165 | public: |
166 | /// \brief Construct a `TensorShapeBase` from the provided sizes. |
167 | /// REQUIRES: `dim_sizes[i] >= 0` (or >= -1 for PartialTensorShape) |
168 | explicit TensorShapeBase(gtl::ArraySlice<int64_t> dim_sizes); |
169 | TensorShapeBase(std::initializer_list<int64_t> dim_sizes) |
170 | : TensorShapeBase(gtl::ArraySlice<int64_t>(dim_sizes)) {} |
171 | |
172 | /// Construct an empty TensorShape, or an unknown rank PartialTensorShape |
173 | TensorShapeBase(); |
174 | |
175 | // Cannot be made explicit because we rely on conversion between proto and |
176 | // `TensorShapeBase` throughtout the codebase (needs bigger cleanup) |
177 | TensorShapeBase(const TensorShapeProto& proto); |
178 | |
179 | // These factory methods should be used instead of the constructors that take |
180 | // an array of sizes if calling code cannot validate that the sizes specify a |
181 | // valid `TensorShape`. |
182 | // The value in `*out` is valid iff the returned value is `Status::OK`. |
183 | static Status BuildTensorShapeBase(gtl::ArraySlice<int64_t> dim_sizes, |
184 | TensorShapeBase* out); |
185 | static Status BuildTensorShapeBase(std::initializer_list<int64_t> dim_sizes, |
186 | TensorShapeBase* out) { |
187 | return BuildTensorShapeBase(gtl::ArraySlice<int64_t>(dim_sizes), out); |
188 | } |
189 | static Status BuildTensorShapeBase(const TensorShapeProto& proto, |
190 | TensorShapeBase* out); |
191 | |
192 | /// Returns `true` iff `proto` is a valid tensor shape. |
193 | // For TensorShape, the proto shape must be fully defined. |
194 | static bool IsValid(const TensorShapeProto& proto); |
195 | |
196 | /// Returns `OK` iff `proto` is a valid tensor shape, and a descriptive error |
197 | /// status otherwise. |
198 | static Status IsValidShape(const TensorShapeProto& proto); |
199 | |
200 | /// Returns `true` iff this is a valid tensor shape. |
201 | bool IsValid(); |
202 | |
203 | /// \brief Add a dimension to the end ("inner-most"). |
204 | /// REQUIRES: `size >= 0` |
205 | void AddDim(int64_t size); |
206 | |
207 | /// Same as `AddDim` but returns a `Status`. |
208 | /// Use if unsure is `size >= 0`, to prevent `CHECK`-crashes. |
209 | Status AddDimWithStatus(int64_t size); |
210 | |
211 | /// Appends all the dimensions from `shape`. |
212 | void AppendShape(const TensorShapeBase& shape); |
213 | |
214 | /// Same as `RemoveDim` but returns a `Status`. |
215 | /// Use if you cannot validate all invariants, to prevent `CHECK`-fail. |
216 | Status AppendShapeWithStatus(const TensorShapeBase& shape); |
217 | |
218 | /// \brief Insert a dimension somewhere in the `TensorShape`. |
219 | /// REQUIRES: `0 <= d <= dims()` |
220 | /// REQUIRES: `size >= 0` |
221 | void InsertDim(int d, int64_t size); |
222 | |
223 | /// Same as `InsertDim` but returns a `Status`. |
224 | /// Use if unsure if requirements in `InsertDim` are satistified, to prevent |
225 | /// `CHECK`-fail crashes. |
226 | Status InsertDimWithStatus(int d, int64_t size); |
227 | |
228 | /// \brief Modifies the size of the dimension `d` to be `size` |
229 | /// REQUIRES: `0 <= d < dims()` |
230 | /// REQUIRES: `size >= 0` |
231 | void set_dim(int d, int64_t size); |
232 | |
233 | /// Same as `set_dim` but returns a `Status`. |
234 | /// Use if unsure if requirements in `set_dim` are satistified, to prevent |
235 | /// `CHECK`-fail crashes. |
236 | Status SetDimWithStatus(int d, int64_t size); |
237 | |
238 | /// \brief Removes dimension `d` from the `TensorShape`. |
239 | /// REQUIRES: `0 <= d < dims()` |
240 | void RemoveDim(int d) { |
241 | CHECK_GE(d, 0); |
242 | RemoveDimRange(d, d + 1); |
243 | } |
244 | |
245 | /// Same as `RemoveDim` but returns a `Status`. |
246 | /// Use if unsure is `0 <= d < dims()`, to prevent `CHECK`-crashes. |
247 | Status RemoveDimWithStatus(int64_t d) { |
248 | if (TF_PREDICT_FALSE(d < 0)) { |
249 | return errors::Internal( |
250 | "Expected dimension index to be non-negative, got " , d); |
251 | } |
252 | return RemoveDimRangeWithStatus(d, d + 1); |
253 | } |
254 | |
255 | /// \brief Removes last `n` dimensions from the `TensorShape`. |
256 | /// REQUIRES: `0 <= n <= dims()` |
257 | void RemoveLastDims(int n) { |
258 | CHECK_LE(n, dims()); |
259 | RemoveDimRange(dims() - n, dims()); |
260 | } |
261 | |
262 | /// Same as `RemoveLastDims` but returns a `Status`. |
263 | /// Use if unsure is `0 <= n <= dims()`, to prevent `CHECK`-crashes. |
264 | Status RemoveLastDimsWithStatus(int64_t n) { |
265 | if (TF_PREDICT_FALSE(n < dims())) { |
266 | return errors::Internal("Expected dimension index to be at most " , dims(), |
267 | " got " , n); |
268 | } |
269 | return RemoveDimRangeWithStatus(dims() - n, dims()); |
270 | } |
271 | |
272 | /// \brief Removes the dimensions in range `[begin:end)` from `TensorShape`. |
273 | /// Negative values of `end` are interpreted as `dims() + end + 1` (as in |
274 | /// Python). The same is true for negative values of `begin`. |
275 | /// REQUIRES: `-(dims()+1) <= begin <= dims()` |
276 | /// REQUIRES: `-(dims()+1) <= end <= dims()` |
277 | void RemoveDimRange(int begin, int end); |
278 | |
279 | /// Same as `RemoveDimRange` but returns a `Status`. |
280 | /// Use if unsure if requirements in `RemoveDimRange` are satistified, to |
281 | /// prevent `CHECK`-fail crashes. |
282 | Status RemoveDimRangeWithStatus(int begin, int end); |
283 | |
284 | /// Return whether the rank is unknown |
285 | bool unknown_rank() const { |
286 | return kIsPartial && ndims_byte() == kUnknownRank; |
287 | } |
288 | |
289 | /// Return the number of dimensions in the tensor. |
290 | /// Can be -1 meaning unknown rank for PartialTensorShape. |
291 | int dims() const { |
292 | uint8 dims = ndims_byte(); |
293 | return kIsPartial && dims == kUnknownRank ? -1 : dims; |
294 | } |
295 | |
296 | /// \brief Returns the number of elements in dimension `d`. |
297 | /// REQUIRES: `0 <= d < dims()` |
298 | // TODO(touts): Rename to `dimension()` to match |
299 | // `Eigen::Tensor::dimension()`? |
300 | int64_t dim_size(int d) const; |
301 | |
302 | /// Returns sizes of all dimensions. |
303 | // Returns an empty list for unknown rank PartialTensorShape. |
304 | gtl::InlinedVector<int64_t, 4> dim_sizes() const; |
305 | |
306 | /// Return true iff the rank and all of the dimensions are well defined |
307 | // TODO(irving): Rename to is_fully_defined now that it's fast. |
308 | bool IsFullyDefined() const { return !kIsPartial || num_elements() != -1; } |
309 | |
310 | /// Fill `*proto` from `*this`. |
311 | void AsProto(TensorShapeProto* proto) const; |
312 | TensorShapeProto AsProto() const; |
313 | |
314 | /// For iterating through the dimensions. |
315 | TensorShapeIter<Shape> begin() const; |
316 | TensorShapeIter<Shape> end() const; |
317 | |
318 | protected: |
319 | // Optimized constructor for a shape representing an empty vector. |
320 | // |
321 | // This constructor is provided to optimize the default constructor for |
322 | // `Tensor`. |
323 | explicit TensorShapeBase(DataType dt); |
324 | |
325 | private: |
326 | Status RecomputeNumElements(); |
327 | Status InitDims(gtl::ArraySlice<int64_t> dim_sizes); |
328 | |
329 | // True for PartialTensorShape, false for TensorShape |
330 | static constexpr bool kIsPartial = |
331 | std::is_same<Shape, PartialTensorShape>::value; |
332 | static_assert(kIsPartial || std::is_same<Shape, TensorShape>::value, |
333 | "Shape is neither TensorShape nor PartialTensorShape" ); |
334 | |
335 | // Used by AddDim and MakeShapeHelper. Does no error checking. |
336 | void UnsafeAddDim(int64_t size, int64_t new_num_elements); |
337 | |
338 | // For use by TensorShapeUtils::MakeShape |
339 | template <class T, class S> |
340 | friend Status MakeShapeHelper(const T*, int64_t, S*); |
341 | }; |
342 | |
343 | /// Outputs `TensorShapeBase` to `std::ostream`. |
344 | template <typename Shape> |
345 | std::ostream& operator<<(std::ostream& os, const TensorShapeBase<Shape>& tsb) { |
346 | return os << tsb.DebugString(); |
347 | } |
348 | |
/// Represents the shape of a Tensor.
///
/// A tensor's shape is denoted by its number of dimensions and a size for each
/// dimension. For example, a Tensor represented by a 3 x 4 matrix would have
/// a shape of 2-D, [3,4].
///
/// If you know the exact shape of your Tensor when you create the TensorShape
/// object, you can specify it then, or you can create a TensorShape with
/// zero dimensions and one element, and call AddDim() to add dimensions later.
class TensorShape : public TensorShapeBase<TensorShape> {
 public:
  using TensorShapeBase<TensorShape>::TensorShapeBase;

  // These factory methods should be used instead of the constructors that take
  // an array of sizes if calling code cannot validate that the sizes specify a
  // valid `TensorShape`.
  // The value in `*out` is valid iff the returned value is `Status::OK`.
  static Status BuildTensorShape(gtl::ArraySlice<int64_t> dim_sizes,
                                 TensorShape* out) {
    return BuildTensorShapeBase(dim_sizes, out);
  }
  static Status BuildTensorShape(std::initializer_list<int64_t> dim_sizes,
                                 TensorShape* out) {
    return BuildTensorShape(gtl::ArraySlice<int64_t>(dim_sizes), out);
  }
  static Status BuildTensorShape(const TensorShapeProto& proto,
                                 TensorShape* out) {
    return BuildTensorShapeBase(proto, out);
  }

  // StatusOr-returning convenience wrapper around the proto overload above.
  static StatusOr<TensorShape> BuildTensorShape(const TensorShapeProto& proto) {
    TensorShape out;
    TF_RETURN_IF_ERROR(BuildTensorShape(proto, &out));
    return out;
  }

  /// Allow a TensorShape to be used as a PartialTensorShape without copying
  operator const PartialTensorShape&() const;  // NOLINT(runtime/explicit)

  /// Returns true if `*this` and `b` have the same sizes. Ignores
  /// dimension names.
  bool IsSameSize(const TensorShape& b) const;
  bool operator==(const TensorShape& b) const { return IsSameSize(b); }
  bool operator!=(const TensorShape& b) const { return !IsSameSize(b); }

  /// Fill `*dsizes` from `*this`.
  /// Notice: Using IndexType=int32 in combination with To32Bit() can
  /// significantly improve performance on GPU.
  template <int NDIMS, typename IndexType = Eigen::DenseIndex>
  Eigen::DSizes<IndexType, NDIMS> AsEigenDSizes() const;

  // Same as `AsEigenDSizes()` but returns a `Status` instead.
  // Use this method to surface error to user instead of crashing if `NDIMS` is
  // not equal to `dims()`.
  // The caller retains ownership of `out`; it is only written to on success.
  template <int NDIMS, typename IndexType = Eigen::DenseIndex>
  Status AsEigenDSizesWithStatus(Eigen::DSizes<IndexType, NDIMS>* out) const;

  /// Same as `AsEigenDSizes()` but allows for `NDIMS > dims()` -- in
  /// which case we pad the rest of the sizes with 1.
  /// Notice: Using IndexType=int32 in combination with To32Bit() can
  /// significantly improve performance on GPU.
  template <int NDIMS, typename IndexType = Eigen::DenseIndex>
  Eigen::DSizes<IndexType, NDIMS> AsEigenDSizesWithPadding() const;

  // Same as `AsEigenDSizesWithPadding()` but returns a `Status` instead.
  // Use this method to surface error to user instead of crashing if `NDIMS` is
  // smaller than `dims()`.
  // The caller retains ownership of `out`; it is only written to on success.
  template <int NDIMS, typename IndexType = Eigen::DenseIndex>
  Status AsEigenDSizesWithPaddingWithStatus(
      Eigen::DSizes<IndexType, NDIMS>* out) const;

 private:
  // These CHECK fail to ease debugging.
  // REQUIRES: dims() == NDIMS
  void CheckDimsEqual(int NDIMS) const;
  // REQUIRES: dims() <= NDIMS
  void CheckDimsAtMost(int NDIMS) const;

  // Fill output from `*this`.
  // Helper method for common code between `AsEigenDSizes()` and
  // `AsEigenDSizesWithStatus()`.
  template <int NDIMS, typename IndexType = Eigen::DenseIndex>
  Eigen::DSizes<IndexType, NDIMS> AsEigenDSizesCopy() const;

  // Fill output from `*this`, padding trailing dimensions with 1.
  // Helper method for common code between `AsEigenDSizesWithPadding()` and
  // `AsEigenDSizesWithPaddingWithStatus()`.
  template <int NDIMS, typename IndexType = Eigen::DenseIndex>
  Eigen::DSizes<IndexType, NDIMS> AsEigenDSizesCopyAndPad() const;

  // For access to TensorShapeBase(DataType).
  friend class Tensor;
};
444 | |
445 | /// Outputs `TensorShapeBase` to `std::ostream`. |
446 | inline std::ostream& operator<<(std::ostream& os, const TensorShape& ts) { |
447 | return os << ts.DebugString(); |
448 | } |
449 | |
/// The size of a single dimension, as yielded by TensorShapeIter.
struct TensorShapeDim {
  explicit TensorShapeDim(int64_t s) : size{s} {}
  // Number of elements along this dimension.
  int64_t size;
};
455 | |
456 | // START_SKIP_DOXYGEN |
457 | template <class Shape> |
458 | class TensorShapeIter { |
459 | public: |
460 | TensorShapeIter(const Shape* shape, int d) : shape_(shape), d_(d) {} |
461 | bool operator==(const TensorShapeIter& rhs) { |
462 | DCHECK(shape_ == rhs.shape_); |
463 | return d_ == rhs.d_; |
464 | } |
465 | bool operator!=(const TensorShapeIter& rhs) { |
466 | DCHECK(shape_ == rhs.shape_); |
467 | return d_ != rhs.d_; |
468 | } |
469 | void operator++() { ++d_; } |
470 | TensorShapeDim operator*() { return TensorShapeDim(shape_->dim_size(d_)); } |
471 | |
472 | private: |
473 | const Shape* shape_; |
474 | int d_; |
475 | }; |
476 | // END_SKIP_DOXYGEN |
477 | |
478 | /// \brief Static helper routines for `TensorShape`. Includes a few common |
479 | /// predicates on a tensor shape. |
480 | class TensorShapeUtils { |
481 | public: |
482 | static bool IsScalar(const TensorShape& shape) { return shape.dims() == 0; } |
483 | |
484 | static bool IsVector(const TensorShape& shape) { return shape.dims() == 1; } |
485 | |
486 | static bool IsVectorOrHigher(const TensorShape& shape) { |
487 | return shape.dims() >= 1; |
488 | } |
489 | |
490 | static bool IsMatrix(const TensorShape& shape) { return shape.dims() == 2; } |
491 | |
492 | static bool IsSquareMatrix(const TensorShape& shape) { |
493 | return shape.dims() == 2 && shape.dim_size(0) == shape.dim_size(1); |
494 | } |
495 | |
496 | static bool IsMatrixOrHigher(const TensorShape& shape) { |
497 | return shape.dims() >= 2; |
498 | } |
499 | |
500 | /// \brief Returns a `TensorShape` whose dimensions are |
501 | /// `dims[0]`, `dims[1]`, ..., `dims[n-1]`. |
502 | static Status MakeShape(const int32* dims, int64_t n, TensorShape* out); |
503 | static Status MakeShape(const int64_t* dims, int64_t n, TensorShape* out); |
504 | static Status MakeShape(gtl::ArraySlice<int32> shape, TensorShape* out); |
505 | static Status MakeShape(gtl::ArraySlice<int64_t> shape, TensorShape* out); |
506 | static Status MakeShape(const int32* dims, int64_t n, |
507 | PartialTensorShape* out); |
508 | static Status MakeShape(const int64_t* dims, int64_t n, |
509 | PartialTensorShape* out); |
510 | static Status MakeShape(gtl::ArraySlice<int32> shape, |
511 | PartialTensorShape* out); |
512 | static Status MakeShape(gtl::ArraySlice<int64_t> shape, |
513 | PartialTensorShape* out); |
514 | |
515 | static std::string ShapeListString( |
516 | const gtl::ArraySlice<TensorShape>& shapes); |
517 | |
518 | /// \brief Returns true iff `shape` starts with `prefix`. |
519 | static bool StartsWith(const TensorShape& shape, const TensorShape& prefix); |
520 | |
521 | /// \brief Returns true iff `shape` ends with `suffix`. |
522 | static bool EndsWith(const TensorShape& shape, const TensorShape& suffix); |
523 | |
524 | /// \brief Returns the product of values in an int64 array, |
525 | /// or a failing Status if the array represents a value larger than |
526 | /// a `TensorShape` can hold. |
527 | static Status NumElements(gtl::ArraySlice<int64_t> shape, |
528 | int64_t* num_elements); |
529 | }; |
530 | |
/// Manages the partially known dimensions of a Tensor and their sizes.
class PartialTensorShape : public TensorShapeBase<PartialTensorShape> {
 public:
  PartialTensorShape() {}
  using TensorShapeBase<PartialTensorShape>::TensorShapeBase;

  // These factory methods should be used instead of the constructors that take
  // an array of sizes if calling code cannot validate that the sizes specify a
  // valid `PartialTensorShape`.
  // The value in `*out` is valid iff the returned value is `Status::OK`.
  static Status BuildPartialTensorShape(gtl::ArraySlice<int64_t> dim_sizes,
                                        PartialTensorShape* out) {
    return BuildTensorShapeBase(dim_sizes, out);
  }
  static Status BuildPartialTensorShape(
      std::initializer_list<int64_t> dim_sizes, PartialTensorShape* out) {
    return BuildPartialTensorShape(gtl::ArraySlice<int64_t>(dim_sizes), out);
  }
  static Status BuildPartialTensorShape(const TensorShapeProto& proto,
                                        PartialTensorShape* out) {
    return BuildTensorShapeBase(proto, out);
  }

  // StatusOr-returning convenience wrapper around the proto overload above.
  static StatusOr<PartialTensorShape> BuildPartialTensorShape(
      const TensorShapeProto& proto) {
    PartialTensorShape out;
    TF_RETURN_IF_ERROR(BuildTensorShapeBase(proto, &out));
    return out;
  }

  /// Add a dimension to the end ("inner-most"), returns a new
  /// PartialTensorShape.
  /// REQUIRES: `size >= -1`, where -1 means unknown
  PartialTensorShape Concatenate(int64_t size) const;

  /// Similar to `Concatenate` but returning `Status`.
  /// Use if calling code cannot validate all requirements and if `CHECK`-fails
  /// are to be avoided.
  Status ConcatenateWithStatus(int64_t size, PartialTensorShape* out) const;

  /// Appends all the dimensions from `shape`. Returns a new
  /// PartialTensorShape.
  PartialTensorShape Concatenate(const PartialTensorShape& shape) const;

  /// Similar to `Concatenate` but returning `Status`.
  /// Use if calling code cannot validate all requirements and if `CHECK`-fails
  /// are to be avoided.
  Status ConcatenateWithStatus(const PartialTensorShape& shape,
                               PartialTensorShape* out) const;

  /// Merges all the dimensions from `shape`. Returns
  /// `InvalidArgument` error if either `shape` has a different rank
  /// or if any of the dimensions are incompatible.
  Status MergeWith(const PartialTensorShape& shape,
                   PartialTensorShape* result) const;

  /// Exact equality test. Returns true iff the ranks match (i.e., both are
  /// unknown, or both are known and equal), and all dimensions are equal
  /// (i.e., for each dimension, both are unknown, or both are known and
  /// equal). This is a stronger condition than IsCompatibleWith.
  bool IsIdenticalTo(const PartialTensorShape& shape) const;

  /// Return true iff the ranks match, and if the
  /// dimensions all either match or one is unknown.
  bool IsCompatibleWith(const PartialTensorShape& shape) const;

  // Fill `*shape` from `*this`.
  // If `*this` is not fully defined, returns false and
  // `*shape` is left in an intermediate state. Otherwise
  // returns true.
  bool AsTensorShape(TensorShape* shape) const;

  /// \brief Returns a `PartialTensorShape` whose dimensions are
  /// `dims[0]`, `dims[1]`, ..., `dims[n-1]`. Values of -1 are
  /// considered "unknown".
  template <class T>
  static Status MakePartialShape(const T* dims, int n,
                                 PartialTensorShape* out) {
    return TensorShapeUtils::MakeShape(dims, n, out);
  }
};
612 | |
/// \brief Static helper routines for `PartialTensorShape`. Includes a few
/// common predicates on a partially known tensor shape.
class PartialTensorShapeUtils {
 public:
  // Renders a list of partial shapes as a single string, for error messages.
  static std::string PartialShapeListString(
      const gtl::ArraySlice<PartialTensorShape>& shapes);

  // True iff the lists have equal length and each pair of corresponding
  // shapes is identical (see PartialTensorShape::IsIdenticalTo).
  static bool AreIdentical(const gtl::ArraySlice<PartialTensorShape>& shapes0,
                           const gtl::ArraySlice<PartialTensorShape>& shapes1);

  // True iff the lists have equal length and each pair of corresponding
  // shapes is compatible (see PartialTensorShape::IsCompatibleWith).
  static bool AreCompatible(const gtl::ArraySlice<PartialTensorShape>& shapes0,
                            const gtl::ArraySlice<PartialTensorShape>& shapes1);
};
626 | |
627 | // ---------------------------------------------------------------------------- |
628 | // Template method implementation details below |
629 | // ---------------------------------------------------------------------------- |
630 | |
631 | template <int NDIMS, typename IndexType> |
632 | Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizesCopy() const { |
633 | Eigen::DSizes<IndexType, NDIMS> dsizes; |
634 | for (int d = 0; d < NDIMS; d++) { |
635 | dsizes[d] = static_cast<IndexType>(dim_size(d)); |
636 | } |
637 | return dsizes; |
638 | } |
639 | |
640 | template <int NDIMS, typename IndexType> |
641 | Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizesCopyAndPad() const { |
642 | static_assert(NDIMS <= TensorShape::MaxDimensions(), "Too many dimensions" ); |
643 | Eigen::DSizes<IndexType, NDIMS> dsizes; |
644 | for (int d = 0; d < dims(); d++) { |
645 | dsizes[d] = static_cast<IndexType>(dim_size(d)); |
646 | } |
647 | for (int d = dims(); d < NDIMS; d++) { |
648 | dsizes[d] = 1; |
649 | } |
650 | return dsizes; |
651 | } |
652 | |
653 | template <int NDIMS, typename IndexType> |
654 | Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizes() const { |
655 | CheckDimsEqual(NDIMS); |
656 | return AsEigenDSizesCopy<NDIMS, IndexType>(); |
657 | } |
658 | |
659 | template <int NDIMS, typename IndexType> |
660 | Status TensorShape::AsEigenDSizesWithStatus( |
661 | Eigen::DSizes<IndexType, NDIMS>* out) const { |
662 | if (TF_PREDICT_FALSE(NDIMS != dims())) { |
663 | return errors::Internal("Asking for tensor of " , NDIMS, |
664 | " dimensions from a tensor of " , dims(), |
665 | " dimensions" ); |
666 | } |
667 | *out = AsEigenDSizesCopy<NDIMS, IndexType>(); |
668 | } |
669 | |
670 | template <int NDIMS, typename IndexType> |
671 | Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizesWithPadding() const { |
672 | CheckDimsAtMost(NDIMS); |
673 | return AsEigenDSizesCopyAndPad<NDIMS, IndexType>(); |
674 | } |
675 | |
676 | template <int NDIMS, typename IndexType> |
677 | Status TensorShape::AsEigenDSizesWithPaddingWithStatus( |
678 | Eigen::DSizes<IndexType, NDIMS>* out) const { |
679 | if (TF_PREDICT_FALSE(NDIMS < dims())) { |
680 | return errors::Internal("Asking for tensor of at least " , NDIMS, |
681 | " dimensions from a tensor of " , dims(), |
682 | " dimensions" ); |
683 | } |
684 | *out = AsEigenDSizesCopyAndPad<NDIMS, IndexType>(); |
685 | } |
686 | |
687 | // ---------------------------------------------------------------------------- |
688 | // Inlining of some performance critical routines |
689 | // ---------------------------------------------------------------------------- |
690 | |
691 | inline TensorShapeRep::TensorShapeRep(const TensorShapeRep& b) { |
692 | num_elements_ = b.num_elements_; |
693 | if (b.tag() != REP_OUT_OF_LINE) { |
694 | memcpy(buf(), b.buf(), sizeof(u_.buf)); |
695 | // memcpy above Implicitly does: |
696 | // set_ndims_byte(b.ndims_byte()); |
697 | // set_tag(b.tag()); |
698 | } else { |
699 | set_tag(REP16); // So that SlowCopyFrom does not try to deallocate |
700 | SlowCopyFrom(b); |
701 | } |
702 | } |
703 | |
// Move constructor: steals b's 16-byte representation wholesale (including
// any out-of-line Rep64 pointer) and retags b as inline so b's destructor
// will not free the stolen storage.
inline TensorShapeRep::TensorShapeRep(TensorShapeRep&& b) {
  num_elements_ = b.num_elements_;
  memcpy(buf(), b.buf(), sizeof(u_.buf));
  // memcpy above implicitly does:
  //   set_ndims_byte(b.ndims_byte());
  //   set_tag(b.tag());
  b.set_tag(REP16);  // other shape no longer owns out-of-line data, if any.
}
712 | |
// Destructor: only the out-of-line representation owns heap memory, so only
// that case needs cleanup.
inline TensorShapeRep::~TensorShapeRep() {
  if (tag() == REP_OUT_OF_LINE) {
    DestructorOutOfLine();
  }
}
718 | |
// Copy assignment. The fast path applies only when neither side involves
// out-of-line storage: no Rep64 to free on the left, none to clone on the
// right, so a raw 16-byte copy suffices.
inline void TensorShapeRep::operator=(const TensorShapeRep& b) {
  num_elements_ = b.num_elements_;
  if (tag() != REP_OUT_OF_LINE && b.tag() != REP_OUT_OF_LINE) {
    memcpy(buf(), b.buf(), sizeof(u_.buf));
    // memcpy above implicitly also does:
    //   set_tag(b.tag());
    //   set_ndims_byte(b.ndims_byte());
  } else {
    // SlowCopyFrom handles freeing our Rep64 and/or cloning b's Rep64.
    SlowCopyFrom(b);
  }
}
730 | |
731 | inline void TensorShapeRep::operator=(TensorShapeRep&& b) { |
732 | if (tag() == REP_OUT_OF_LINE) { |
733 | DestructorOutOfLine(); |
734 | } |
735 | num_elements_ = b.num_elements_; |
736 | memcpy(buf(), b.buf(), sizeof(u_.buf)); |
737 | // memcpy above Implicitly does: |
738 | // set_ndims_byte(b.ndims_byte()); |
739 | // set_tag(b.tag()); |
740 | b.set_tag(REP16); // other shape no longer owns out-of-line data, if any. |
741 | } |
742 | |
// View a TensorShape as a PartialTensorShape without copying.
// NOTE(review): this relies on TensorShape and PartialTensorShape sharing the
// exact TensorShapeRep representation and adding no data members (see the
// TensorShapeBase class comment) — the cast would be unsafe otherwise.
inline TensorShape::operator const PartialTensorShape&() const {
  // Downcast to the shared representation and upcast to PartialTensorShape
  const TensorShapeRep* rep = this;
  return *static_cast<const PartialTensorShape*>(rep);
}
748 | |
// Builds the shape {0} — rank 1 with a single dimension of size 0 — directly
// in the Rep16 representation, stashing `dt` in the spare byte for Tensor.
template <class Shape>
inline TensorShapeBase<Shape>::TensorShapeBase(DataType dt) {
  set_tag(REP16);
  set_data_type(dt);

  // Optimized implementation of InitDims() where the shape is statically known
  // to be {0}.
  set_ndims_byte(1);
  uint16* dst = as16()->dims_;
  *dst = 0;  // single dimension of size 0, hence zero elements
  set_num_elements(0);
}
761 | |
762 | // Declare explicit instantiations in .cc file |
763 | extern template class TensorShapeBase<TensorShape>; |
764 | extern template class TensorShapeBase<PartialTensorShape>; |
765 | |
// A convenient struct to represent a (DataType, PartialTensorShape) pair. It's
// often used in shape inference.
struct DtypeAndPartialTensorShape {
  // Element type of the tensor.
  DataType dtype;
  // Possibly partially-known shape of the tensor.
  PartialTensorShape shape;
};
772 | |
773 | } // namespace tensorflow |
774 | |
775 | #endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_ |
776 | |