/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_H_
#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_H_

#include <cstdint>
#include <type_traits>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// Forward declarations. In particular, we forward declare protos so that their
// symbols can be removed from .so exports.
class AllocationDescription;
class OpKernelContext;
class Tensor;
class TensorBuffer;
class TensorCApi;
class TensorInterface;
class TensorCord;
class TensorDescription;
class TensorProto;
class Var;

namespace batch_util {
Status CopyElementToSlice(Tensor element, Tensor* parent, int64_t index);
Status CopySliceToElement(const Tensor& parent, Tensor* element, int64_t index);
Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, int64_t index);
Status CopyContiguousSlices(const Tensor& src, int64_t src_offset,
                            int64_t dst_offset, int64_t num_slices,
                            Tensor* dst);
}  // namespace batch_util

/// @ingroup core

/// Interface to access the raw ref-counted data buffer.
class TensorBuffer : public core::RefCounted {
 public:
  explicit TensorBuffer(void* data_ptr) : data_(data_ptr) {}
  ~TensorBuffer() override {}

  /// \brief data() points to a memory region of size() bytes.
  ///
  /// NOTE(mrry): The `data()` method is not virtual for performance reasons.
  /// It can be called multiple times when the contents of a `Tensor` are
  /// accessed, and so making it non-virtual allows the body to be inlined.
  void* data() const { return data_; }

  /// \brief Size (in bytes) of the buffer.
  virtual size_t size() const = 0;
  /// \brief If this TensorBuffer is a sub-buffer of another TensorBuffer,
  /// returns that TensorBuffer. Otherwise, returns this.
  virtual TensorBuffer* root_buffer() = 0;

  /// \brief Fills metadata about the allocation into the proto.
  virtual void FillAllocationDescription(
      AllocationDescription* proto) const = 0;

  virtual bool GetAllocatedBytes(size_t* out_bytes) const;

  /// \brief Helper method to reinterpret the buffer as an array of `T`.
  template <typename T>
  T* base() const {
    return reinterpret_cast<T*>(data());
  }

  /// \brief Whether this TensorBuffer owns the underlying memory.
  virtual bool OwnsMemory() const { return true; }

  /// \brief The type of the underlying memory.
  virtual AllocatorMemoryType GetMemoryType() const {
    return AllocatorMemoryType::kUnknown;
  }

 private:
  void* const data_;
};
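
// As an informal illustration (not part of this header's API surface), a
// subclass can wrap memory it does not own by overriding `OwnsMemory()`;
// the names below are hypothetical:
//
//   class WrappedBuffer : public TensorBuffer {
//    public:
//     WrappedBuffer(void* data, size_t len) : TensorBuffer(data), len_(len) {}
//     size_t size() const override { return len_; }
//     TensorBuffer* root_buffer() override { return this; }
//     void FillAllocationDescription(
//         AllocationDescription* proto) const override {}
//     bool OwnsMemory() const override { return false; }  // owned elsewhere
//
//    private:
//     size_t len_;
//   };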

/// Represents an n-dimensional array of values.
class Tensor {
 public:
  /// \brief Creates a 1-dimensional, 0-element float tensor.
  ///
  /// The returned Tensor is not a scalar (shape {}), but is instead
  /// an empty one-dimensional Tensor (shape {0}, NumElements() ==
  /// 0). Since it has no elements, it does not need to be assigned a
  /// value and is initialized by default (IsInitialized() is
  /// true). If this is undesirable, consider creating a one-element
  /// scalar which does require initialization:
  ///
  /// ```c++
  ///
  ///     Tensor(DT_FLOAT, TensorShape({}))
  ///
  /// ```
  Tensor();

  /// \brief Creates a Tensor of the given `type` and `shape`. If
  /// LogMemory::IsEnabled() the allocation is logged as coming from
  /// an unknown kernel and step. Calling the Tensor constructor
  /// directly from within an Op is deprecated: use the
  /// OpKernelConstruction/OpKernelContext allocate_* methods to
  /// allocate a new tensor, which record the kernel and step.
  ///
  /// The underlying buffer is allocated using a `CPUAllocator`.
  Tensor(DataType type, const TensorShape& shape);

  /// \brief Creates a tensor with the input `type` and `shape`, using
  /// the allocator `a` to allocate the underlying buffer. If
  /// LogMemory::IsEnabled() the allocation is logged as coming from
  /// an unknown kernel and step. Calling the Tensor constructor
  /// directly from within an Op is deprecated: use the
  /// OpKernelConstruction/OpKernelContext allocate_* methods to
  /// allocate a new tensor, which record the kernel and step.
  ///
  /// `a` must outlive the lifetime of this Tensor.
  Tensor(Allocator* a, DataType type, const TensorShape& shape);

  /// \brief Creates a tensor with the input `type` and `shape`, using
  /// the allocator `a` and the specified "allocation_attr" to
  /// allocate the underlying buffer. If the kernel and step are known
  /// allocation_attr.allocation_will_be_logged should be set to true
  /// and LogMemory::RecordTensorAllocation should be called after the
  /// tensor is constructed. Calling the Tensor constructor directly
  /// from within an Op is deprecated: use the
  /// OpKernelConstruction/OpKernelContext allocate_* methods to
  /// allocate a new tensor, which record the kernel and step.
  ///
  /// `a` must outlive the lifetime of this Tensor.
  Tensor(Allocator* a, DataType type, const TensorShape& shape,
         const AllocationAttributes& allocation_attr);

  /// \brief Creates a tensor with the input datatype, shape and buf.
  ///
  /// Acquires a ref on buf that belongs to this Tensor.
  Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf);
  /// \brief Creates a tensor with the input datatype, shape and buf.
  ///
  /// Takes ownership of the buffer from the reference-counted pointer.
  Tensor(DataType type, TensorShape shape, core::RefCountPtr<TensorBuffer> buf);

  /// \brief Creates an empty Tensor of the given data type.
  ///
  /// Like Tensor(), returns a 1-dimensional, 0-element Tensor with
  /// IsInitialized() returning true. See the Tensor() documentation
  /// for details.
  explicit Tensor(DataType type);

  /// \brief Initializes a tensor with the input `type` and `shape`, or returns
  /// an error and leaves `out_tensor` unmodified. This factory method should be
  /// used instead of the corresponding constructor if calling code cannot
  /// validate that the `DataType` is valid and supported.
  ///
  /// The underlying buffer is allocated using a `CPUAllocator`.
  static Status BuildTensor(DataType type, const TensorShape& shape,
                            Tensor* out_tensor);
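
  // For illustration only (a minimal sketch, not exercised by this header):
  // the constructor requires a valid, supported type, while `BuildTensor`
  // reports problems through the returned Status instead.
  //
  //   Tensor direct(DT_FLOAT, TensorShape({2, 3}));
  //
  //   Tensor validated;
  //   Status s =
  //       Tensor::BuildTensor(DT_FLOAT, TensorShape({2, 3}), &validated);
  //   if (!s.ok()) { /* handle the invalid DataType */ }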

 private:
  // A tag type for selecting the `Tensor` constructor overload that creates a
  // scalar tensor in host memory.
  struct host_scalar_tag {};

  class HostScalarTensorBufferBase;
  template <typename T>
  struct ValueAndTensorBuffer;

  // Creates a tensor with the given scalar `value` in CPU memory.
  template <typename T>
  Tensor(T value, host_scalar_tag tag);

 public:
  // A series of specialized constructors for scalar tensors in host memory.
  //
  // NOTE: The `Variant` host-scalar constructor is not defined, because Variant
  // is implicitly constructible from many different types, and this causes
  // ambiguities with some compilers.
  explicit Tensor(float scalar_value)
      : Tensor(scalar_value, host_scalar_tag{}) {}
  explicit Tensor(double scalar_value)
      : Tensor(scalar_value, host_scalar_tag{}) {}
  explicit Tensor(int32_t scalar_value)
      : Tensor(scalar_value, host_scalar_tag{}) {}
  explicit Tensor(uint32 scalar_value)
      : Tensor(scalar_value, host_scalar_tag{}) {}
  explicit Tensor(uint16 scalar_value)
      : Tensor(scalar_value, host_scalar_tag{}) {}
  explicit Tensor(uint8 scalar_value)
      : Tensor(scalar_value, host_scalar_tag{}) {}
  explicit Tensor(int16_t scalar_value)
      : Tensor(scalar_value, host_scalar_tag{}) {}
  explicit Tensor(int8_t scalar_value)
      : Tensor(scalar_value, host_scalar_tag{}) {}
  explicit Tensor(tstring scalar_value)
      : Tensor(std::move(scalar_value), host_scalar_tag{}) {}
  explicit Tensor(complex64 scalar_value)
      : Tensor(scalar_value, host_scalar_tag{}) {}
  explicit Tensor(complex128 scalar_value)
      : Tensor(scalar_value, host_scalar_tag{}) {}
  explicit Tensor(int64_t scalar_value)
      : Tensor(scalar_value, host_scalar_tag{}) {}
  explicit Tensor(uint64 scalar_value)
      : Tensor(scalar_value, host_scalar_tag{}) {}
  explicit Tensor(bool scalar_value)
      : Tensor(scalar_value, host_scalar_tag{}) {}
  explicit Tensor(qint8 scalar_value)
      : Tensor(scalar_value, host_scalar_tag{}) {}
  explicit Tensor(quint8 scalar_value)
      : Tensor(scalar_value, host_scalar_tag{}) {}
  explicit Tensor(qint16 scalar_value)
      : Tensor(scalar_value, host_scalar_tag{}) {}
  explicit Tensor(quint16 scalar_value)
      : Tensor(scalar_value, host_scalar_tag{}) {}
  explicit Tensor(qint32 scalar_value)
      : Tensor(scalar_value, host_scalar_tag{}) {}
  explicit Tensor(bfloat16 scalar_value)
      : Tensor(scalar_value, host_scalar_tag{}) {}
  explicit Tensor(Eigen::half scalar_value)
      : Tensor(scalar_value, host_scalar_tag{}) {}
  explicit Tensor(ResourceHandle scalar_value)
      : Tensor(std::move(scalar_value), host_scalar_tag{}) {}

  // NOTE: The `const char*` host-scalar constructor is provided as a
  // convenience because otherwise passing a string literal would surprisingly
  // construct a DT_BOOL tensor.
  explicit Tensor(const char* scalar_value)
      : Tensor(tstring(scalar_value), host_scalar_tag{}) {}
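
  // Informal usage sketch (values are arbitrary): each of these creates a
  // 0-dimensional host tensor whose dtype follows the argument type.
  //
  //   Tensor a(3.0f);        // DT_FLOAT scalar
  //   Tensor b(int64_t{7});  // DT_INT64 scalar
  //   Tensor c("hello");     // DT_STRING scalar (not DT_BOOL)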

  /// Copy constructor.
  Tensor(const Tensor& other);

  /// \brief Move constructor. After this call, <other> is safely destructible,
  /// can be assigned to, and IsInitialized() can be called and will return
  /// false. Other calls on <other> (e.g. shape manipulation) are not valid.
  Tensor(Tensor&& other);

  // Explicitly delete constructors that take a pointer (except char*)
  // so that the pointer doesn't get implicitly cast to bool.
  template <typename T, typename std::enable_if<!std::is_same<T, char>::value,
                                                T>::type* = nullptr>
  explicit Tensor(T* t) = delete;

  ~Tensor();

  /// Returns the data type.
  DataType dtype() const { return shape_.data_type(); }

  /// Returns the shape of the tensor.
  const TensorShape& shape() const { return shape_; }

  /// \brief Convenience accessor for the tensor shape.
  ///
  /// For all shape accessors, see comments for relevant methods of
  /// `TensorShape` in `tensor_shape.h`.
  int dims() const { return shape().dims(); }

  /// Convenience accessor for the tensor shape.
  int64_t dim_size(int d) const { return shape().dim_size(d); }

  /// Convenience accessor for the tensor shape.
  int64_t NumElements() const { return shape().num_elements(); }

  bool IsSameSize(const Tensor& b) const {
    return shape().IsSameSize(b.shape());
  }

  // True iff the two tensors use the same underlying refcounted storage
  bool SharesBufferWith(const Tensor& b) const;

  /// \brief Has this Tensor been initialized (where initialization is
  /// required)?
  ///
  /// Zero-element Tensors are always considered initialized, even if they
  /// have never been assigned to and do not have any memory allocated.
  bool IsInitialized() const;

  /// Returns the estimated memory usage of this tensor.
  size_t TotalBytes() const;

  // Returns the size of allocated memory for this tensor.
  size_t AllocatedBytes() const;

  /// Returns true iff this tensor is aligned.
  bool IsAligned() const {
#if EIGEN_MAX_ALIGN_BYTES == 0
    return true;
#else
    void* ptr = base<void>();
    return dtype() == DT_STRING || NumElements() == 0 ||
           (reinterpret_cast<intptr_t>(ptr) % EIGEN_MAX_ALIGN_BYTES == 0);
#endif
  }

  /// Assign operator. This tensor shares other's underlying storage.
  Tensor& operator=(const Tensor& other) {
    CopyFromInternal(other, other.shape());
    return *this;
  }

  /// Move operator. See move constructor for details.
  Tensor& operator=(Tensor&& other);

  /// \brief Copy the other tensor into this tensor and reshape it.
  ///
  /// This tensor shares other's underlying storage. Returns `true`
  /// iff `other.shape()` has the same number of elements as the given
  /// `shape`.
  bool CopyFrom(const Tensor& other,
                const TensorShape& shape) TF_MUST_USE_RESULT {
    if (other.NumElements() != shape.num_elements()) return false;
    CopyFromInternal(other, shape);
    return true;
  }
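
  // A minimal sketch of the intended use (shapes are illustrative): reshaping
  // a {2, 3} tensor into {3, 2} without copying the underlying buffer.
  //
  //   Tensor src(DT_FLOAT, TensorShape({2, 3}));
  //   Tensor reshaped;
  //   if (!reshaped.CopyFrom(src, TensorShape({3, 2}))) {
  //     // element counts did not match; `reshaped` is left unmodified
  //   }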

  /// \brief Slice this tensor along the 1st dimension.

  /// I.e., the returned tensor satisfies
  ///     returned[i, ...] == this[dim0_start + i, ...].
  /// The returned tensor shares the underlying tensor buffer with this
  /// tensor.
  ///
  /// NOTE: The returned tensor may not satisfy the same alignment
  /// requirement as this tensor depending on the shape. The caller
  /// must check the returned tensor's alignment before calling certain
  /// methods that have an alignment requirement (e.g., `flat()`, `tensor()`).
  ///
  /// NOTE: When fed with an N-dimensional tensor, this method returns a tensor
  /// also with N dimensions. If you want to select a sub tensor, see SubSlice.
  ///
  /// REQUIRES: `dims()` >= 1
  /// REQUIRES: `0 <= dim0_start <= dim0_limit <= dim_size(0)`
  Tensor Slice(int64_t dim0_start, int64_t dim0_limit) const;

  /// \brief Select a subslice from this tensor along the 1st dimension.
  ///
  /// When fed with an N-dimensional tensor, this method returns a tensor with
  /// N-1 dimensions, where the returned tensor is a subslice of the input
  /// tensor along the first dimension. The N-1 dimensions of the returned
  /// tensor are the last N-1 dimensions of the input tensor.
  ///
  /// NOTE: The returned tensor may not satisfy the same alignment
  /// requirement as this tensor depending on the shape. The caller
  /// must check the returned tensor's alignment before calling certain
  /// methods that have an alignment requirement (e.g., `flat()`, `tensor()`).
  ///
  /// REQUIRES: `dims()` >= 1
  /// REQUIRES: `0 <= index < dim_size(0)`
  Tensor SubSlice(int64_t index) const;
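
  // Informal sketch (shapes illustrative): Slice keeps the rank, SubSlice
  // drops the first dimension; alignment must be re-checked before `flat()`.
  //
  //   Tensor t(DT_FLOAT, TensorShape({10, 4}));
  //   Tensor rows = t.Slice(2, 5);  // shape {3, 4}
  //   Tensor row = t.SubSlice(2);   // shape {4}
  //   if (row.IsAligned()) {
  //     auto values = row.flat<float>();
  //   }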

  /// \brief Parse `other` and construct the tensor.

  /// Returns `true` iff the parsing succeeds. If the parsing fails,
  /// the state of `*this` is unchanged.
  bool FromProto(const TensorProto& other) TF_MUST_USE_RESULT;
  bool FromProto(Allocator* a, const TensorProto& other) TF_MUST_USE_RESULT;

  /// \brief Fills in `proto` with `*this` tensor's content.
  ///
  /// `AsProtoField()` fills in the repeated field for `proto.dtype()`, while
  /// `AsProtoTensorContent()` encodes the content in `proto.tensor_content()`
  /// in a compact form.
  void AsProtoField(TensorProto* proto) const;
  void AsProtoTensorContent(TensorProto* proto) const;
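
  // A brief round-trip sketch (illustrative only): serialize a tensor to a
  // TensorProto and reconstruct it, checking the result of FromProto.
  //
  //   Tensor original(DT_FLOAT, TensorShape({2, 2}));
  //   TensorProto proto;
  //   original.AsProtoTensorContent(&proto);
  //
  //   Tensor restored;
  //   if (!restored.FromProto(proto)) {
  //     // parsing failed; `restored` is unchanged
  //   }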

  /// \brief Return the tensor data as an `Eigen::Tensor` with the type and
  /// sizes of this `Tensor`.
  ///
  /// Use these methods when you know the data type and the number of
  /// dimensions of the Tensor and you want an `Eigen::Tensor`
  /// automatically sized to the `Tensor` sizes. The implementation check
  /// fails if either the type or the sizes mismatch.
  ///
  /// Example:
  ///
  /// ```c++
  ///
  ///     typedef float T;
  ///     Tensor my_mat(...built with Shape{rows: 3, cols: 5}...);
  ///     auto mat = my_mat.matrix<T>();    // 2D Eigen::Tensor, 3 x 5.
  ///     auto mat = my_mat.tensor<T, 2>(); // 2D Eigen::Tensor, 3 x 5.
  ///     auto vec = my_mat.vec<T>();       // CHECK fails as my_mat is 2D.
  ///     auto vec = my_mat.tensor<T, 3>(); // CHECK fails as my_mat is 2D.
  ///     auto mat = my_mat.matrix<int32>();// CHECK fails as type mismatch.
  ///
  /// ```
  template <typename T>
  typename TTypes<T>::Vec vec() {
    return tensor<T, 1>();
  }

  template <typename T>
  typename TTypes<T>::Matrix matrix() {
    return tensor<T, 2>();
  }

  template <typename T, size_t NDIMS>
  typename TTypes<T, NDIMS>::Tensor tensor() TF_ATTRIBUTE_NOINLINE;

  /// \brief Return the tensor data to an `Eigen::Tensor` with the
  /// same size but a bitwise cast to the specified dtype `T`.
  ///
  /// Using a bitcast is useful for move and copy operations.
  /// NOTE: this is the same as `tensor()` except a bitcast is allowed.
  template <typename T, size_t NDIMS>
  typename TTypes<T, NDIMS>::Tensor bit_casted_tensor();

  /// \brief Return the tensor data to an `Eigen::Tensor` with the
  /// last dimension elements converted into single elements of a larger type.
  ///
  /// For example, this is useful for kernels that can treat NCHW_VECT_C int8
  /// tensors as NCHW int32 tensors. The sizeof(T) should equal the size of
  /// the original element type * num elements in the original last dimension.
  /// NDIMS should be 1 less than the original number of dimensions.
  template <typename T, size_t NDIMS>
  typename TTypes<T, NDIMS>::Tensor reinterpret_last_dimension();
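
  // Informal sketch of the NCHW_VECT_C case described above (dimensions are
  // illustrative): a rank-5 int8 tensor whose last dimension has 4 elements
  // can be viewed as a rank-4 int32 tensor, since 4 * sizeof(int8) ==
  // sizeof(int32).
  //
  //   Tensor vect_c(DT_INT8, TensorShape({8, 16, 32, 32, 4}));
  //   auto as_int32 = vect_c.reinterpret_last_dimension<int32, 4>();
  //   // as_int32 has dimensions {8, 16, 32, 32}.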

  /// \brief Return the tensor data as an `Eigen::Tensor` of the data type and a
  /// specified shape.
  ///
  /// These methods allow you to access the data with the dimensions
  /// and sizes of your choice. You do not need to know the number of
  /// dimensions of the Tensor to call them. However, they `CHECK` that
  /// the type matches and the requested dimensions create an
  /// `Eigen::Tensor` with the same number of elements as the tensor.
  ///
  /// Example:
  ///
  /// ```c++
  ///
  ///     typedef float T;
  ///     Tensor my_ten(...built with Shape{planes: 4, rows: 3, cols: 5}...);
  ///     // 1D Eigen::Tensor, size 60:
  ///     auto flat = my_ten.flat<T>();
  ///     // 2D Eigen::Tensor 12 x 5:
  ///     auto inner = my_ten.flat_inner_dims<T>();
  ///     // 2D Eigen::Tensor 4 x 15:
  ///     auto outer = my_ten.shaped<T, 2>({4, 15});
  ///     // CHECK fails, bad num elements:
  ///     auto outer = my_ten.shaped<T, 2>({4, 8});
  ///     // 3D Eigen::Tensor 6 x 5 x 2:
  ///     auto weird = my_ten.shaped<T, 3>({6, 5, 2});
  ///     // CHECK fails, type mismatch:
  ///     auto bad = my_ten.flat<int32>();
  ///
  /// ```
  template <typename T>
  typename TTypes<T>::Flat flat();

  template <typename T>
  typename TTypes<T>::UnalignedFlat unaligned_flat() {
    return unaligned_shaped<T, 1>({NumElements()});
  }

  /// Returns the data as an Eigen::Tensor with NDIMS dimensions, collapsing all
  /// Tensor dimensions but the last NDIMS-1 into the first dimension of the
  /// result. If NDIMS > dims() then leading dimensions of size 1 will be
  /// added to make the output rank NDIMS.
  template <typename T, size_t NDIMS = 2>
  typename TTypes<T, NDIMS>::Tensor flat_inner_dims();

  /// Returns the data as an Eigen::Tensor with NDIMS dimensions, collapsing all
  /// Tensor dimensions but the first NDIMS-1 into the last dimension of the
  /// result. If NDIMS > dims() then trailing dimensions of size 1 will be
  /// added to make the output rank NDIMS.
  template <typename T, size_t NDIMS = 2>
  typename TTypes<T, NDIMS>::Tensor flat_outer_dims();

  /// Returns the data as an Eigen::Tensor with NDIMS dimensions, collapsing the
  /// first 'begin' Tensor dimensions into the first dimension of the result and
  /// the Tensor dimensions of the last dims() - 'begin' - NDIMS into the last
  /// dimension of the result. If 'begin' < 0 then the |'begin'| leading
  /// dimensions of size 1 will be added. If 'begin' + NDIMS > dims() then
  /// 'begin' + NDIMS - dims() trailing dimensions of size 1 will be added.
  template <typename T, size_t NDIMS = 3>
  typename TTypes<T, NDIMS>::Tensor flat_inner_outer_dims(int64_t begin);
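
  // Informal sketch of the collapsing rules above, for a tensor of shape
  // {2, 3, 5, 7} (shapes are illustrative):
  //
  //   Tensor t(DT_FLOAT, TensorShape({2, 3, 5, 7}));
  //   auto a = t.flat_inner_dims<float>();     // 2D Eigen::Tensor, 30 x 7
  //   auto b = t.flat_outer_dims<float>();     // 2D Eigen::Tensor, 2 x 105
  //   auto c = t.flat_inner_dims<float, 3>();  // 3D Eigen::Tensor, 6 x 5 x 7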

  template <typename T, size_t NDIMS>
  typename TTypes<T, NDIMS>::Tensor shaped(gtl::ArraySlice<int64_t> new_sizes);

  /// \brief Return the tensor data to an `Eigen::Tensor` with the new
  /// shape specified in `new_sizes` and cast to a new dtype `T`.
  ///
  /// Using a bitcast is useful for move and copy operations.
  /// The allowed bitcast is the only difference from `shaped()`.
  template <typename T, size_t NDIMS>
  typename TTypes<T, NDIMS>::Tensor bit_casted_shaped(
      gtl::ArraySlice<int64_t> new_sizes);

  template <typename T, size_t NDIMS>
  typename TTypes<T, NDIMS>::UnalignedTensor unaligned_shaped(
      gtl::ArraySlice<int64_t> new_sizes);

  /// \brief Return the Tensor data as a `TensorMap` of fixed size 1:
  /// `TensorMap<TensorFixedSize<T, 1>>`.

  /// Using `scalar()` allows the compiler to perform optimizations as
  /// the size of the tensor is known at compile time.
  template <typename T>
  typename TTypes<T>::Scalar scalar();

  /// Const versions of all the methods above.
  template <typename T>
  typename TTypes<T>::ConstVec vec() const {
    return tensor<T, 1>();
  }

  template <typename T>
  typename TTypes<T>::ConstMatrix matrix() const {
    return tensor<T, 2>();
  }

  template <typename T, size_t NDIMS>
  typename TTypes<T, NDIMS>::ConstTensor tensor() const TF_ATTRIBUTE_NOINLINE;

  /// \brief Return the tensor data to an `Eigen::Tensor` with the
  /// same size but a bitwise cast to the specified dtype `T`.
  ///
  /// Using a bitcast is useful for move and copy operations.
  /// NOTE: this is the same as `tensor()` except a bitcast is allowed.
  template <typename T, size_t NDIMS>
  typename TTypes<T, NDIMS>::ConstTensor bit_casted_tensor() const;

  /// \brief Return the tensor data to an `Eigen::Tensor` with the
  /// last dimension elements converted into single elements of a larger type.
  ///
  /// For example, this is useful for kernels that can treat NCHW_VECT_C int8
  /// tensors as NCHW int32 tensors. The sizeof(T) should equal the size of
  /// the original element type * num elements in the original last dimension.
  /// NDIMS should be 1 less than the original number of dimensions.
  template <typename T, size_t NDIMS>
  typename TTypes<T, NDIMS>::ConstTensor reinterpret_last_dimension() const;

  template <typename T>
  typename TTypes<T>::ConstFlat flat() const;

  template <typename T>
  typename TTypes<T>::UnalignedConstFlat unaligned_flat() const {
    return unaligned_shaped<T, 1>({NumElements()});
  }

  template <typename T, size_t NDIMS>
  typename TTypes<T, NDIMS>::ConstTensor shaped(
      gtl::ArraySlice<int64_t> new_sizes) const;

  /// \brief Return the tensor data to an `Eigen::Tensor` with the new
  /// shape specified in `new_sizes` and cast to a new dtype `T`.
  ///
  /// Using a bitcast is useful for move and copy operations.
  /// The allowed bitcast is the only difference from `shaped()`.
  template <typename T, size_t NDIMS>
  typename TTypes<T, NDIMS>::ConstTensor bit_casted_shaped(
      gtl::ArraySlice<int64_t> new_sizes) const;

  template <typename T, size_t NDIMS>
  typename TTypes<T, NDIMS>::UnalignedConstTensor unaligned_shaped(
      gtl::ArraySlice<int64_t> new_sizes) const;

  template <typename T>
  typename TTypes<T>::ConstScalar scalar() const;

  template <typename T, size_t NDIMS = 2>
  typename TTypes<T, NDIMS>::ConstTensor flat_inner_dims() const;

  template <typename T, size_t NDIMS = 2>
  typename TTypes<T, NDIMS>::ConstTensor flat_outer_dims() const;

  template <typename T, size_t NDIMS = 3>
  typename TTypes<T, NDIMS>::ConstTensor flat_inner_outer_dims(
      int64_t begin) const;

  /// Render the first `max_entries` values in `*this` into a string.
  std::string SummarizeValue(int64_t max_entries, bool print_v2 = false) const;

  /// A human-readable summary of the tensor suitable for debugging.
  // `num_values` is the number of actual data values in the tensor
  // included in the message. If the tensor might be resident in
  // GPU/TPU memory use DeviceSafeDebugString instead.
  std::string DebugString(int num_values) const;
  std::string DebugString() const { return DebugString(3); }

  // Variant of DebugString() that should be used for possibly non-CPU tensors.
  // If the tensor is not resident on CPU, we can't read its values as
  // DebugString() does.
  std::string DeviceSafeDebugString() const;

  /// Fill in the `TensorDescription` proto with metadata about the
  /// tensor that is useful for monitoring and debugging.
  void FillDescription(TensorDescription* description) const;

  /// \brief Returns a `StringPiece` mapping the current tensor's buffer.
  ///
  /// The returned `StringPiece` may point to a memory location on devices
  /// that the CPU cannot address directly.
  ///
  /// NOTE: The underlying tensor buffer is refcounted, so the lifetime
  /// of the contents mapped by the `StringPiece` matches the lifetime of
  /// the buffer; callers should arrange to make sure the buffer does
  /// not get destroyed while the `StringPiece` is still used.
  ///
  /// REQUIRES: `DataTypeCanUseMemcpy(dtype())`.
  StringPiece tensor_data() const;
  void* data() const;
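
  // Minimal sketch of reading the raw bytes of a CPU-resident tensor; the
  // dtype check mirrors the REQUIRES clause above.
  //
  //   Tensor t(DT_FLOAT, TensorShape({4}));
  //   if (DataTypeCanUseMemcpy(t.dtype())) {
  //     StringPiece bytes = t.tensor_data();
  //     // The view covers the tensor's bytes and is valid only while
  //     // `t`'s buffer stays alive.
  //   }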

  /// Copy the other tensor into this tensor, reshape it and reinterpret the
  /// buffer's datatype. If an ok Status is returned, the two tensors now share
  /// the same underlying storage.
  ///
  /// This call requires that the `other` tensor and the given type and shape
  /// are "compatible" (i.e. they occupy the same number of bytes).
  ///
  /// Specifically:
  ///
  ///     shape.num_elements() * DataTypeSize(type)
  ///
  /// must equal
  ///
  ///     other.num_elements() * DataTypeSize(other.dtype())
  ///
  /// In addition, this function requires:
  ///   * DataTypeSize(other.dtype()) != 0
  ///   * DataTypeSize(type) != 0
  ///
  /// If any of the requirements are not met, errors::InvalidArgument is
  /// returned.
  Status BitcastFrom(const Tensor& other, DataType dtype,
                     const TensorShape& shape);
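
  // Informal sketch of a byte-compatible bitcast (sizes chosen so that
  // 8 * sizeof(int32) == 8 * sizeof(float)):
  //
  //   Tensor ints(DT_INT32, TensorShape({2, 4}));
  //   Tensor floats;
  //   Status s = floats.BitcastFrom(ints, DT_FLOAT, TensorShape({8}));
  //   // On success, `floats` views the same storage as `ints`.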

  /// Like BitcastFrom, but CHECK fails if any preconditions are not met.
  ///
  /// Deprecated. Use BitcastFrom instead and check the returned Status.
  void UnsafeCopyFromInternal(const Tensor& other, DataType dtype,
                              const TensorShape& shape) {
    TF_CHECK_OK(BitcastFrom(other, dtype, shape));
  }

  // Returns true if the refcount on buf_ and any possible underlying root
  // buffer is one.
  bool RefCountIsOne() const;

  // Returns the type of the underlying memory.
  AllocatorMemoryType GetMemoryType() const { return buf_->GetMemoryType(); }

 private:
  void CheckType(DataType expected_dtype) const;
  void CheckTypeAndIsAligned(DataType expected_dtype) const;
  void CheckIsAlignedAndSingleElement() const;
  void set_dtype(DataType t) { shape_.set_data_type(t); }

  // TensorShape's InlineVector.
  static gtl::InlinedVector<int64_t, 4> ComputeFlatInnerDims(
      gtl::ArraySlice<int64_t> orig, int64_t num_out_dims);
  static gtl::InlinedVector<int64_t, 4> ComputeFlatOuterDims(
      gtl::ArraySlice<int64_t> orig, int64_t num_out_dims);

  TensorShape shape_;
  TensorBuffer* buf_;

  friend class DMAHelper;             // For access to buf_.
  friend class TensorCApi;            // For access to buf_.
  friend class TensorCord;            // For access to buf_.
  friend class TensorReference;       // For access to buf_.
  friend class VariableOp;            // For access to set_shape.
  friend class AutoReloadVariableOp;  // For access to set_shape.
  friend class TensorTestHelper;      // For access to set_shape.
  friend class TensorInterface;       // For access to set_shape.
  friend class CastOpBase;            // For access to set_dtype.
  friend class ScopedAllocator;       // For access to buf_.
  friend Status batch_util::CopyElementToSlice(
      Tensor element, Tensor* parent,
      int64_t index);  // For access to base<T>().
  friend Status batch_util::CopySliceToElement(
      const Tensor& parent, Tensor* element,
      int64_t index);  // For access to base<T>().
  friend Status batch_util::MaybeMoveSliceToElement(
      Tensor* parent, Tensor* element,
      int64_t index);  // For access to base<T>().
  friend Status batch_util::CopyContiguousSlices(
      const Tensor& src, int64_t src_offset, int64_t dst_offset,
      int64_t num_slices,
      Tensor* dst);  // For access to base<T>().

  bool CanUseDMA() const;

  // Only needed by variable op to set the shape of an uninitialized
  // Tensor.
  // TODO: Remove this when we have a better story for detecting
  // uninitialized tensors.
  void set_shape(const TensorShape& shape) {
    DataType dt = dtype();
    shape_ = shape;
    set_dtype(dt);
  }

  inline void CopyFromInternal(const Tensor& other, const TensorShape& shape) {
    DCHECK_EQ(shape.num_elements(), other.NumElements());
    // Data type will be overwritten if this == &other, since dtype is part of
    // shape.
    DataType other_dtype = other.dtype();
    shape_ = shape;
    set_dtype(other_dtype);
    if (buf_ != other.buf_) {
      if (buf_) buf_->Unref();
      buf_ = other.buf_;
      if (buf_) buf_->Ref();
    }
  }

  template <typename T>
  T* base() const;

  template <size_t NDIMS>
  void FillDimsAndValidateCompatibleShape(
      gtl::ArraySlice<int64_t> new_sizes,
      Eigen::array<Eigen::DenseIndex, NDIMS>* dims) const;

  template <typename T, size_t NDIMS>
  void FillDimsAndValidateCompatibleShape(
      gtl::ArraySlice<int64_t> new_sizes,
      Eigen::array<Eigen::DenseIndex, NDIMS>* dims) const;
};

// Implementation details

// START_SKIP_DOXYGEN

template <typename T>
T* Tensor::base() const {
  return buf_ == nullptr ? nullptr : buf_->base<T>();
}

// This routine is defined out of line for code-space savings
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor Tensor::tensor() {
  CheckTypeAndIsAligned(DataTypeToEnum<T>::v());
  return typename TTypes<T, NDIMS>::Tensor(base<T>(),
                                           shape().AsEigenDSizes<NDIMS>());
}

// This routine is defined out of line for code-space savings
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor Tensor::tensor() const {
  CheckTypeAndIsAligned(DataTypeToEnum<T>::v());
  return typename TTypes<T, NDIMS>::ConstTensor(base<const T>(),
                                                shape().AsEigenDSizes<NDIMS>());
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor Tensor::bit_casted_tensor() {
  CHECK(IsAligned());
  return typename TTypes<T, NDIMS>::Tensor(base<T>(),
                                           shape().AsEigenDSizes<NDIMS>());
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor Tensor::bit_casted_tensor() const {
  CHECK(IsAligned());
  return typename TTypes<T, NDIMS>::ConstTensor(base<const T>(),
                                                shape().AsEigenDSizes<NDIMS>());
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor Tensor::reinterpret_last_dimension() {
  if (NDIMS == dims()) {
    return tensor<T, NDIMS>();
  }
  CHECK(IsAligned());
  CHECK_EQ(static_cast<int>(NDIMS), dims() - 1);
  CHECK_EQ(static_cast<int64_t>(sizeof(T)),
           shape_.dim_sizes()[NDIMS] * DataTypeSize(dtype()));
  Eigen::array<Eigen::DenseIndex, NDIMS> dims;
  for (int d = 0; d < NDIMS; ++d) {
    dims[d] = shape_.dim_sizes()[d];
  }
  return typename TTypes<T, NDIMS>::Tensor(base<T>(), dims);
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor Tensor::reinterpret_last_dimension()
    const {
  if (NDIMS == dims()) {
    return tensor<T, NDIMS>();
  }
  CHECK(IsAligned());
  CHECK_EQ(static_cast<int>(NDIMS), dims() - 1);
  CHECK_EQ(static_cast<int64_t>(sizeof(T)),
           shape_.dim_sizes()[NDIMS] * DataTypeSize(dtype()));
  Eigen::array<Eigen::DenseIndex, NDIMS> dims;
  for (int d = 0; d < NDIMS; ++d) {
    dims[d] = shape_.dim_sizes()[d];
  }
  return typename TTypes<T, NDIMS>::ConstTensor(base<const T>(), dims);
}

template <size_t NDIMS>
void Tensor::FillDimsAndValidateCompatibleShape(
    gtl::ArraySlice<int64_t> new_sizes,
    Eigen::array<Eigen::DenseIndex, NDIMS>* dims) const {
  CHECK_EQ(NDIMS, new_sizes.size());
  int64_t new_num_elements = 1;
  for (size_t d = 0; d < NDIMS; d++) {
    new_num_elements *= new_sizes[d];
    (*dims)[d] = new_sizes[d];
  }
  CHECK_EQ(new_num_elements, NumElements());
}

template <typename T, size_t NDIMS>
void Tensor::FillDimsAndValidateCompatibleShape(
    gtl::ArraySlice<int64_t> new_sizes,
    Eigen::array<Eigen::DenseIndex, NDIMS>* dims) const {
  CHECK_EQ(NDIMS, new_sizes.size());
  int64_t new_num_elements = 1;
  for (size_t d = 0; d < NDIMS; d++) {
    new_num_elements *= new_sizes[d];
    (*dims)[d] = new_sizes[d];
  }
  const int element_size = DataTypeSize(BaseType(dtype()));
  if (element_size > 0) {
    CHECK_EQ(new_num_elements * static_cast<int64_t>(sizeof(T)),
             NumElements() * element_size);
  } else {
    // DataTypeSize() returns 0 for some data types. In this case, assume that T
    // has the same size as the buffer type.
    // NOTE: If we can be sure that DataTypeSize() does not return 0 for all POD
    // types, then we should check DataTypeToEnum<T>::v() == dtype(). Or simply
    // check if `element_size > 0` to err when bit cast is attempted on Tensor
    // of unknown data type size.
    CHECK_EQ(new_num_elements, NumElements());
  }
}

template <typename T>
typename TTypes<T>::Flat Tensor::flat() {
  // Equivalent to 'return shaped<T, 1>({NumElements()});'
  CheckTypeAndIsAligned(DataTypeToEnum<T>::v());
  Eigen::array<Eigen::DenseIndex, 1> dims;
  dims[0] = NumElements();
  return typename TTypes<T, 1>::Tensor(base<T>(), dims);
}

template <typename T>
typename TTypes<T>::ConstFlat Tensor::flat() const {
  // Equivalent to 'return shaped<T, 1>({NumElements()});'
  CheckTypeAndIsAligned(DataTypeToEnum<T>::v());
  Eigen::array<Eigen::DenseIndex, 1> dims;
  dims[0] = NumElements();
  return typename TTypes<T, 1>::ConstTensor(base<T>(), dims);
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor Tensor::shaped(
    gtl::ArraySlice<int64_t> new_sizes) {
  CheckTypeAndIsAligned(DataTypeToEnum<T>::v());
  Eigen::array<Eigen::DenseIndex, NDIMS> dims;
  FillDimsAndValidateCompatibleShape(new_sizes, &dims);
  return typename TTypes<T, NDIMS>::Tensor(base<T>(), dims);
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor Tensor::bit_casted_shaped(
    gtl::ArraySlice<int64_t> new_sizes) {
  CHECK(IsAligned());
  Eigen::array<Eigen::DenseIndex, NDIMS> dims;
  FillDimsAndValidateCompatibleShape<T>(new_sizes, &dims);
  return typename TTypes<T, NDIMS>::Tensor(base<T>(), dims);
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::UnalignedTensor Tensor::unaligned_shaped(
    gtl::ArraySlice<int64_t> new_sizes) {
  CheckType(DataTypeToEnum<T>::v());
  Eigen::array<Eigen::DenseIndex, NDIMS> dims;
  FillDimsAndValidateCompatibleShape(new_sizes, &dims);
  return typename TTypes<T, NDIMS>::UnalignedTensor(base<T>(), dims);
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor Tensor::shaped(
    gtl::ArraySlice<int64_t> new_sizes) const {
  CheckType(DataTypeToEnum<T>::v());
  CHECK(IsAligned()) << "ptr = " << base<void>();
  Eigen::array<Eigen::DenseIndex, NDIMS> dims;
  FillDimsAndValidateCompatibleShape(new_sizes, &dims);
  return typename TTypes<T, NDIMS>::ConstTensor(base<T>(), dims);
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor Tensor::bit_casted_shaped(
    gtl::ArraySlice<int64_t> new_sizes) const {
  CHECK(IsAligned());
  Eigen::array<Eigen::DenseIndex, NDIMS> dims;
  FillDimsAndValidateCompatibleShape<T>(new_sizes, &dims);
  return typename TTypes<T, NDIMS>::ConstTensor(base<T>(), dims);
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::UnalignedConstTensor Tensor::unaligned_shaped(
    gtl::ArraySlice<int64_t> new_sizes) const {
  CheckType(DataTypeToEnum<T>::v());
  Eigen::array<Eigen::DenseIndex, NDIMS> dims;
  FillDimsAndValidateCompatibleShape(new_sizes, &dims);
  return typename TTypes<T, NDIMS>::UnalignedConstTensor(base<T>(), dims);
}

template <typename T>
typename TTypes<T>::Scalar Tensor::scalar() {
  static_assert(
      !std::is_same<T, std::string>::value,
      "std::string is no longer a scalar type, use tensorflow::tstring");
  CheckIsAlignedAndSingleElement();
  return typename TTypes<T>::Scalar(base<T>());
}

template <typename T>
typename TTypes<T>::ConstScalar Tensor::scalar() const {
  static_assert(
      !std::is_same<T, std::string>::value,
      "std::string is no longer a scalar type, use tensorflow::tstring");
  CheckIsAlignedAndSingleElement();
  return typename TTypes<T>::ConstScalar(base<T>());
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor Tensor::flat_inner_dims() {
  return shaped<T, NDIMS>(ComputeFlatInnerDims(shape_.dim_sizes(), NDIMS));
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor Tensor::flat_outer_dims() {
  return shaped<T, NDIMS>(ComputeFlatOuterDims(shape_.dim_sizes(), NDIMS));
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor Tensor::flat_inner_outer_dims(int64_t begin) {
  gtl::InlinedVector<int64_t, 4> flat_outer =
      ComputeFlatOuterDims(shape_.dim_sizes(), begin + NDIMS);
  return shaped<T, NDIMS>(ComputeFlatInnerDims(flat_outer, NDIMS));
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor Tensor::flat_inner_dims() const {
  return shaped<T, NDIMS>(ComputeFlatInnerDims(shape_.dim_sizes(), NDIMS));
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor Tensor::flat_outer_dims() const {
  return shaped<T, NDIMS>(ComputeFlatOuterDims(shape_.dim_sizes(), NDIMS));
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor Tensor::flat_inner_outer_dims(
    int64_t begin) const {
  gtl::InlinedVector<int64_t, 4> flat_outer =
      ComputeFlatOuterDims(shape_.dim_sizes(), begin + NDIMS);
  return shaped<T, NDIMS>(ComputeFlatInnerDims(flat_outer, NDIMS));
}

inline Tensor::Tensor(const Tensor& other)
    : shape_(other.shape()), buf_(other.buf_) {
  if (buf_) buf_->Ref();
}

inline Tensor::Tensor(Tensor&& other)
    : shape_(std::move(other.shape_)), buf_(other.buf_) {
  other.buf_ = nullptr;
}

class Tensor::HostScalarTensorBufferBase : public TensorBuffer {
 public:
  using TensorBuffer::TensorBuffer;
  bool GetAllocatedBytes(size_t* out_bytes) const final;
  void FillAllocationDescription(AllocationDescription* proto) const final;
};

// A packed representation for a single scalar value of type `T`, and a
// `TensorBuffer` implementation that describes (and manages the lifetime of)
// that value.
template <typename T>
struct Tensor::ValueAndTensorBuffer {
  class HostScalarTensorBuffer : public Tensor::HostScalarTensorBufferBase {
   public:
    explicit HostScalarTensorBuffer(void* data)
        : HostScalarTensorBufferBase(data) {}
    size_t size() const final { return sizeof(T); }
    TensorBuffer* root_buffer() final { return this; }

    // Override `operator delete` so that calling `delete this` in
    // `core::RefCounted::Unref()` for an object of this type will free
    // the enclosing `ValueAndTensorBuffer` for the tensor buffer.
    //
    // NOTE(mrry): The definition of this method must be outside the class
    // definition in order to satisfy some compilers.
    static void operator delete(void* ptr);

    static void operator delete(void*, void*) {
      // Some compilers require an overridden class-specific deallocation
      // function, which will be called if placement `new` throws an
      // exception.
    }

   private:
    ~HostScalarTensorBuffer() override { static_cast<T*>(data())->~T(); }
  };

  T value;
  HostScalarTensorBuffer tensor_buffer;
};

/* static */
template <typename T>
void Tensor::ValueAndTensorBuffer<T>::HostScalarTensorBuffer::operator delete(
    void* ptr) {
  // Use a dummy object to compute the offset of
  // `ValueAndTensorBuffer::tensor_buffer`, because `offsetof()` is not
  // necessarily defined on this non-POD type (until C++17).
  //
  // NOTE(mrry): Using `sizeof(Tensor::ValueAndTensorBuffer<T>)` here requires
  // us to define this method outside the class definition, so that it is not
  // considered an incomplete type.
  typename std::aligned_storage<sizeof(Tensor::ValueAndTensorBuffer<T>),
                                alignof(Tensor::ValueAndTensorBuffer<T>)>::type
      dummy_storage_;
  Tensor::ValueAndTensorBuffer<T>* dummy_object =
      reinterpret_cast<Tensor::ValueAndTensorBuffer<T>*>(&dummy_storage_);
  intptr_t offset = reinterpret_cast<intptr_t>(&dummy_object->tensor_buffer) -
                    reinterpret_cast<intptr_t>(dummy_object);

  port::AlignedFree(static_cast<char*>(ptr) - offset);
}

template <typename T>
Tensor::Tensor(T value, host_scalar_tag tag) {
  auto* value_and_buf = static_cast<Tensor::ValueAndTensorBuffer<T>*>(
      port::AlignedMalloc(sizeof(typename Tensor::ValueAndTensorBuffer<T>),
                          EIGEN_MAX_ALIGN_BYTES));
  new (&value_and_buf->value) T(std::move(value));
  new (&value_and_buf->tensor_buffer)
      typename Tensor::ValueAndTensorBuffer<T>::HostScalarTensorBuffer(
          value_and_buf);
  buf_ = &value_and_buf->tensor_buffer;
  set_dtype(DataTypeToEnum<T>::value);
}

inline Tensor& Tensor::operator=(Tensor&& other) {
  // Avoid self-assignment, since we might destroy our underlying buffer.
  if (&other != this) {
    shape_ = std::move(other.shape_);
    if (buf_) buf_->Unref();
    buf_ = other.buf_;
    other.buf_ = nullptr;
  }
  return *this;
}

// END_SKIP_DOXYGEN

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_TENSOR_H_