/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implementation notes:
//
// Tensor.cc uses a few templated classes and structs to facilitate
// implementation of the Tensor class.
//
// * Buffer<T>: provides the implementation for a typed array T[n].
//   The array is allocated by the given allocator. It runs T's
//   default constructors and destructors when T is not a simple type
//   (e.g., string), and skips them otherwise.
//
// * Helper<T>: provides various routines given type T. The routines
//   include running the constructors and destructors of T[], and
//   encoding and decoding T[] into/from a Cord, etc.

#include "tensorflow/core/framework/tensor.h"

#include <utility>

#include "absl/strings/escaping.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/resource_handle.h"
#include "tensorflow/core/framework/resource_handle.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_description.pb.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/typed_allocator.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// Allow Tensors to be stored inside Variants with automatic
// encoding/decoding when those Variants are themselves being decoded
// in a Tensor's FromProto.
//
// NOTE(mrry): The corresponding "copy function" registrations can be found in
// ../common_runtime/copy_tensor.cc (due to dependencies on other common_runtime
// code).
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(Tensor, "tensorflow::Tensor");

bool TensorBuffer::GetAllocatedBytes(size_t* out_bytes) const {
  AllocationDescription allocation_description;
  FillAllocationDescription(&allocation_description);
  if (allocation_description.allocated_bytes() > 0) {
    *out_bytes = allocation_description.allocated_bytes();
    return true;
  } else {
    return false;
  }
}

namespace {

// An un-templated base class for Buffer.
class BufferBase : public TensorBuffer {
 public:
  explicit BufferBase(Allocator* alloc, void* data_ptr)
      : TensorBuffer(data_ptr), alloc_(alloc) {}

  TensorBuffer* root_buffer() override { return this; }

  bool GetAllocatedBytes(size_t* out_bytes) const override {
    if (alloc_->TracksAllocationSizes()) {
      *out_bytes = alloc_->AllocatedSize(data());
      return *out_bytes > 0;
    } else {
      return false;
    }
  }

  void FillAllocationDescription(AllocationDescription* proto) const override {
    void* data_ptr = data();
    int64_t rb = size();
    proto->set_requested_bytes(rb);
    proto->set_allocator_name(alloc_->Name());
    proto->set_ptr(reinterpret_cast<uintptr_t>(data_ptr));
    if (alloc_->TracksAllocationSizes()) {
      int64_t ab = alloc_->AllocatedSize(data_ptr);
      proto->set_allocated_bytes(ab);
      int64_t id = alloc_->AllocationId(data_ptr);
      if (id > 0) {
        proto->set_allocation_id(id);
      }
      if (RefCountIsOne()) {
        proto->set_has_single_reference(true);
      }
    }
  }

  // Returns the type of the underlying memory.
  AllocatorMemoryType GetMemoryType() const override {
    return alloc_->GetMemoryType();
  }

 protected:
  void RecordDeallocation() {
    LogMemory::RecordTensorDeallocation(alloc_->AllocationId(data()),
                                        alloc_->Name());
  }

  Allocator* const alloc_;
};

// Typed ref-counted buffer: T[n].
template <typename T>
class Buffer : public BufferBase {
 public:
  Buffer(Allocator* a, int64_t n);
  Buffer(Allocator* a, int64_t n, const AllocationAttributes& allocation_attr);

  size_t size() const override { return sizeof(T) * elem_; }

 private:
  int64_t elem_;

  ~Buffer() override;

  TF_DISALLOW_COPY_AND_ASSIGN(Buffer);
};

void LogUnexpectedSize(int64_t actual, int64_t expected) {
  LOG(ERROR) << "Input size was " << actual << " and expected " << expected;
}

bool MemoryLoggingEnabled() {
  static bool memory_logging_enabled = LogMemory::IsEnabled();
  return memory_logging_enabled;
}

// A set of helper functions depending on T.
template <typename T>
struct Helper {
  // By default, we assume T is a simple type (float, int32, etc.).
  static_assert(is_simple_type<T>::value, "T is not a simple type.");
  typedef protobuf::RepeatedField<T> RepeatedFieldType;

  // Encoder of simple type T to a string. We do a copy.
  template <typename Destination>
  static void Encode(TensorBuffer* in, int64_t n, Destination* out) {
    DCHECK_EQ(in->size(), sizeof(T) * n);
    port::AssignRefCounted(StringPiece(in->base<const char>(), in->size()), in,
                           out);
  }

  // Decoder of simple type T. Copy the bytes from "in" into the
  // tensor buffer.
  template <typename Source>
  static TensorBuffer* Decode(Allocator* a, const Source& in, int64_t n) {
    if (in.size() != sizeof(T) * n) {
      LogUnexpectedSize(in.size(), sizeof(T) * n);
      return nullptr;
    }
    Buffer<T>* buf = new Buffer<T>(a, n);
    char* data = buf->template base<char>();
    if (data == nullptr) {
      buf->Unref();
      return nullptr;
    }
    port::CopyToArray(in, data);
    return buf;
  }

  // Memory usage.
  static int64_t TotalBytes(TensorBuffer* in, int64_t n) {
    DCHECK_EQ(in->size(), sizeof(T) * n);
    return in->size();
  }
};

// Helper specialization for string (the only non-simple type we
// support).
template <>
struct Helper<tstring> {
  // Proto message uses RepeatedFieldType to hold repeated T.
  typedef protobuf::RepeatedPtrField<string> RepeatedFieldType;

  // Encodes "n" elements of type string stored in "in" into Cord
  // "out", which is usually the TensorProto::tensor_content.
  template <typename Destination>
  static void Encode(TensorBuffer* in, int64_t n, Destination* out) {
    port::EncodeStringList(in->base<const tstring>(), n, out);
  }

  // Decodes "n" elements of type string from "in" and constructs a
  // buffer out of it. Returns nullptr if the decoding fails. "in" is
  // usually the TensorProto::tensor_content.
  template <typename Source>
  static TensorBuffer* Decode(Allocator* a, const Source& in, int64_t n) {
    Buffer<tstring>* buf = new Buffer<tstring>(a, n);
    tstring* strings = buf->template base<tstring>();
    if (strings == nullptr || !port::DecodeStringList(in, strings, n)) {
      buf->Unref();
      return nullptr;
    }
    return buf;
  }

  // Returns the estimated memory usage of "n" elements of type T
  // stored in buffer "in".
  static int64_t TotalBytes(TensorBuffer* in, int n) {
    int64_t tot = in->size();
    DCHECK_EQ(tot, sizeof(tstring) * n);
    const tstring* p = in->base<const tstring>();
    for (int i = 0; i < n; ++i, ++p) tot += p->size();
    return tot;
  }
};

template <>
struct Helper<ResourceHandle> {
  // Proto message uses RepeatedFieldType to hold repeated T.
  typedef protobuf::RepeatedPtrField<string> RepeatedFieldType;

  // Encodes "n" elements of type ResourceHandle stored in "in" into destination
  // "out", which is usually the TensorProto::tensor_content.
  template <typename Destination>
  static void Encode(TensorBuffer* in, int64_t n, Destination* out) {
    EncodeResourceHandleList(in->base<const ResourceHandle>(), n,
                             port::NewStringListEncoder(out));
  }

  // Decodes "n" elements of type ResourceHandle from "in" and constructs a
  // buffer out of it. Returns nullptr if the decoding fails. "in" is
  // usually the TensorProto::tensor_content.
  template <typename Source>
  static TensorBuffer* Decode(Allocator* a, const Source& in, int64_t n) {
    auto* buf = new Buffer<ResourceHandle>(a, n);
    ResourceHandle* ps = buf->template base<ResourceHandle>();
    if (ps == nullptr ||
        !DecodeResourceHandleList(port::NewStringListDecoder(in), ps, n)) {
      buf->Unref();
      return nullptr;
    }
    return buf;
  }

  // Returns the estimated memory usage of "n" elements of type T
  // stored in buffer "in".
  static int64_t TotalBytes(TensorBuffer* in, int n) {
    return n * sizeof(ResourceHandle);
  }
};

template <>
struct Helper<Variant> {
  // Encodes "n" elements of type Variant stored in "in" into destination
  // "out", which is usually the TensorProto::tensor_content.
  template <typename Destination>
  static void Encode(TensorBuffer* in, int64_t n, Destination* out) {
    EncodeVariantList(in->base<const Variant>(), n,
                      port::NewStringListEncoder(out));
  }

  // Decodes "n" elements of type Variant from "in" and constructs a
  // buffer out of it. Returns nullptr if the decoding fails. "in" is
  // usually the TensorProto::tensor_content.
  template <typename Source>
  static TensorBuffer* Decode(Allocator* a, const Source& in, int64_t n) {
    auto* buf = new Buffer<Variant>(a, n);
    Variant* ps = buf->template base<Variant>();
    if (ps == nullptr ||
        !DecodeVariantList(port::NewStringListDecoder(in), ps, n)) {
      buf->Unref();
      return nullptr;
    }
    return buf;
  }

  // Returns the estimated memory usage of "n" elements of type T
  // stored in buffer "in".
  static int64_t TotalBytes(TensorBuffer* in, int n) {
    return n * sizeof(Variant);
  }
};

template <typename T>
struct ProtoHelper {};

// For a C++ type "T" (float, double, int32, etc.), the repeated field
// "N"_val (float_val, int_val, string_val, etc.) of type "F" (float,
// int32, string, etc.) in the TensorProto is used for serializing the
// tensor of type "T".
#define PROTO_TRAITS(T, F, N)                                          \
  template <>                                                          \
  struct ProtoHelper<T> {                                              \
    typedef Helper<F>::RepeatedFieldType FieldType;                    \
    static FieldType::const_iterator Begin(const TensorProto& proto) { \
      return proto.N##_val().begin();                                  \
    }                                                                  \
    static size_t NumElements(const TensorProto& proto) {              \
      return proto.N##_val().size();                                   \
    }                                                                  \
    static void Fill(const T* data, size_t n, TensorProto* proto) {    \
      typename ProtoHelper<T>::FieldType copy(data, data + n);         \
      proto->mutable_##N##_val()->Swap(&copy);                         \
    }                                                                  \
  };
PROTO_TRAITS(float, float, float);
PROTO_TRAITS(double, double, double);
PROTO_TRAITS(int32, int32, int);
PROTO_TRAITS(uint8, int32, int);
PROTO_TRAITS(uint16, int32, int);
PROTO_TRAITS(uint32, uint32, uint32);
PROTO_TRAITS(int16, int32, int);
PROTO_TRAITS(int8, int32, int);
PROTO_TRAITS(bool, bool, bool);
PROTO_TRAITS(tstring, tstring, string);
PROTO_TRAITS(qint8, int32, int);
PROTO_TRAITS(quint8, int32, int);
PROTO_TRAITS(qint16, int32, int);
PROTO_TRAITS(quint16, int32, int);
#undef PROTO_TRAITS
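
// For reference (illustrative sketch, not part of the original file): the
// invocation PROTO_TRAITS(uint8, int32, int) above expands, roughly, to a
// specialization that reads and writes TensorProto's int_val field while
// presenting uint8 to C++ callers:
//
//   template <>
//   struct ProtoHelper<uint8> {
//     typedef Helper<int32>::RepeatedFieldType FieldType;
//     static FieldType::const_iterator Begin(const TensorProto& proto) {
//       return proto.int_val().begin();
//     }
//     static size_t NumElements(const TensorProto& proto) {
//       return proto.int_val().size();
//     }
//     static void Fill(const uint8* data, size_t n, TensorProto* proto) {
//       typename ProtoHelper<uint8>::FieldType copy(data, data + n);
//       proto->mutable_int_val()->Swap(&copy);
//     }
//   };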

template <>
struct ProtoHelper<int64_t> {
  static protobuf::RepeatedField<int64_t>::const_iterator Begin(
      const TensorProto& proto) {
    return proto.int64_val().begin();
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.int64_val().size();
  }
  static void Fill(const int64_t* data, size_t n, TensorProto* proto) {
    protobuf::RepeatedField<protobuf_int64> copy(data, data + n);
    proto->mutable_int64_val()->Swap(&copy);
  }
};

template <>
struct ProtoHelper<uint64> {
  static protobuf::RepeatedField<uint64_t>::const_iterator Begin(
      const TensorProto& proto) {
    return proto.uint64_val().begin();
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.uint64_val().size();
  }
  static void Fill(const uint64* data, size_t n, TensorProto* proto) {
    protobuf::RepeatedField<protobuf_uint64> copy(data, data + n);
    proto->mutable_uint64_val()->Swap(&copy);
  }
};

template <>
struct ProtoHelper<ResourceHandle> {
  static protobuf::RepeatedPtrField<ResourceHandleProto>::const_iterator Begin(
      const TensorProto& proto) {
    return proto.resource_handle_val().begin();
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.resource_handle_val().size();
  }
  static void Fill(const ResourceHandle* data, size_t n, TensorProto* proto) {
    auto* handles = proto->mutable_resource_handle_val();
    handles->Clear();
    for (size_t i = 0; i < n; i++) {
      data[i].AsProto(handles->Add());
    }
  }
};

template <>
struct ProtoHelper<Variant> {
  static protobuf::RepeatedPtrField<VariantTensorDataProto>::const_iterator
  Begin(const TensorProto& proto) {
    return proto.variant_val().begin();
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.variant_val().size();
  }
  static void Fill(const Variant* data, size_t n, TensorProto* proto) {
    auto* variant_values = proto->mutable_variant_val();
    variant_values->Clear();
    for (size_t i = 0; i < n; ++i) {
      VariantTensorData tmp;
      data[i].Encode(&tmp);
      tmp.ToProto(variant_values->Add());
    }
  }
};

template <>
struct ProtoHelper<complex64> {
  typedef Helper<float>::RepeatedFieldType FieldType;
  static const complex64* Begin(const TensorProto& proto) {
    return reinterpret_cast<const complex64*>(proto.scomplex_val().data());
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.scomplex_val().size() / 2;
  }
  static void Fill(const complex64* data, size_t n, TensorProto* proto) {
    const float* p = reinterpret_cast<const float*>(data);
    FieldType copy(p, p + n * 2);
    proto->mutable_scomplex_val()->Swap(&copy);
  }
};

template <>
struct ProtoHelper<complex128> {
  typedef Helper<double>::RepeatedFieldType FieldType;
  static const complex128* Begin(const TensorProto& proto) {
    return reinterpret_cast<const complex128*>(proto.dcomplex_val().data());
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.dcomplex_val().size() / 2;
  }
  static void Fill(const complex128* data, size_t n, TensorProto* proto) {
    const double* p = reinterpret_cast<const double*>(data);
    FieldType copy(p, p + n * 2);
    proto->mutable_dcomplex_val()->Swap(&copy);
  }
};

template <>
struct ProtoHelper<qint32> {
  typedef Helper<int32>::RepeatedFieldType FieldType;
  static const qint32* Begin(const TensorProto& proto) {
    return reinterpret_cast<const qint32*>(proto.int_val().data());
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.int_val().size();
  }
  static void Fill(const qint32* data, size_t n, TensorProto* proto) {
    const int32* p = reinterpret_cast<const int32*>(data);
    FieldType copy(p, p + n);
    proto->mutable_int_val()->Swap(&copy);
  }
};

template <>
struct ProtoHelper<bfloat16> {
  static void Fill(const bfloat16* data, size_t n, TensorProto* proto) {
    proto->mutable_half_val()->Reserve(n);
    for (size_t i = 0; i < n; ++i) {
      proto->mutable_half_val()->AddAlreadyReserved(
          Eigen::numext::bit_cast<uint16>(data[i]));
    }
  }
};

template <>
struct ProtoHelper<Eigen::half> {
  static void Fill(const Eigen::half* data, size_t n, TensorProto* proto) {
    proto->mutable_half_val()->Reserve(n);
    for (size_t i = 0; i < n; ++i) {
      proto->mutable_half_val()->AddAlreadyReserved(
          Eigen::numext::bit_cast<uint16>(data[i]));
    }
  }
};

template <typename T>
Buffer<T>::Buffer(Allocator* a, int64_t n)
    : BufferBase(a, TypedAllocator::Allocate<T>(a, n, AllocationAttributes())),
      elem_(n) {}

template <typename T>
Buffer<T>::Buffer(Allocator* a, int64_t n,
                  const AllocationAttributes& allocation_attr)
    : BufferBase(a, TypedAllocator::Allocate<T>(a, n, allocation_attr)),
      elem_(n) {}

template <typename T>
Buffer<T>::~Buffer() {
  if (data()) {
    if (MemoryLoggingEnabled()) {
      RecordDeallocation();
    }
    TypedAllocator::Deallocate<T>(alloc_, static_cast<T*>(data()), elem_);
  }
}

// Allocates a T[n] buffer. Fills in the buffer with repeated values
// in "in". If "in" has fewer values than "n", fills the rest of T[n]
// with the last value. If "in" has no values, fills T[n] with the
// default value for T.
//
// This routine uses the typed fields (float_val, etc.) in the
// tensor proto as opposed to the untyped binary representation
// (tensor_content). This is used when we expect the TensorProto to be
// used by a client program which may not know how to encode a tensor
// in the compact binary representation.
template <typename T>
TensorBuffer* FromProtoField(Allocator* a, const TensorProto& in, int64_t n) {
  CHECK_GT(n, 0);
  Buffer<T>* buf = new Buffer<T>(a, n);
  T* data = buf->template base<T>();
  if (data == nullptr) {
    buf->Unref();
    return nullptr;
  }

  const int64_t in_n = ProtoHelper<T>::NumElements(in);
  if (in_n <= 0) {
    std::fill_n(data, n, T());
  } else {
    auto begin = ProtoHelper<T>::Begin(in);
    if (n <= in_n) {
      std::copy_n(begin, n, data);
    } else {
      std::copy_n(begin, in_n, data);
      if (std::is_trivially_copyable<T>::value) {
        const T last = *(data + in_n - 1);
        std::fill_n(data + in_n, n - in_n, last);
      } else {
        const T& last = *(data + in_n - 1);
        std::fill_n(data + in_n, n - in_n, last);
      }
    }
  }

  return buf;
}
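
// Illustrative example (not part of the original file): decoding a
// TensorProto whose float_val holds {1.0, 2.0} into a 4-element tensor yields
// [1.0, 2.0, 2.0, 2.0], because the last supplied value is repeated to fill
// the remainder, as described above:
//
//   TensorProto proto;
//   proto.set_dtype(DT_FLOAT);
//   proto.mutable_tensor_shape()->add_dim()->set_size(4);
//   proto.add_float_val(1.0f);
//   proto.add_float_val(2.0f);
//   Tensor t;
//   if (t.FromProto(proto)) {
//     // t.flat<float>() now holds {1, 2, 2, 2}.
//   }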

// Separate implementation for `ResourceHandle` to handle the case when the
// proto for the resource is invalid. See `resource_handle.h` constructor and
// static factory builder.
template <>
TensorBuffer* FromProtoField<ResourceHandle>(Allocator* a,
                                             const TensorProto& in, int64_t n) {
  CHECK_GT(n, 0);
  Buffer<ResourceHandle>* buf = new Buffer<ResourceHandle>(a, n);
  ResourceHandle* data = buf->template base<ResourceHandle>();
  if (data == nullptr) {
    buf->Unref();
    return nullptr;
  }
  const int64_t in_n = ProtoHelper<ResourceHandle>::NumElements(in);
  if (in_n <= 0) {
    std::fill_n(data, n, ResourceHandle());
  } else {
    // If the tensor shape says we have n < in_n elements in the output tensor,
    // decode only the first n of the in_n input elements. In all other cases,
    // decode all in_n input elements and set the remaining elements up to n
    // to the default ResourceHandle() value.
    const int64_t real_n = n < in_n ? n : in_n;
    for (int64_t i = 0; i < real_n; ++i) {
      Status s = ResourceHandle::BuildResourceHandle(in.resource_handle_val(i),
                                                     &data[i]);
      if (!s.ok()) {
        LOG(ERROR) << "Could not decode resource handle from proto \""
                   << in.resource_handle_val(i).ShortDebugString()
                   << "\", returned status: " << s.ToString();
        buf->Unref();
        return nullptr;
      }
    }
    for (int64_t i = in_n; i < n; ++i) {
      data[i] = ResourceHandle();
    }
  }
  return buf;
}

template <>
TensorBuffer* FromProtoField<Variant>(Allocator* a, const TensorProto& in,
                                      int64_t n) {
  CHECK_GT(n, 0);
  Buffer<Variant>* buf = new Buffer<Variant>(a, n);
  Variant* data = buf->template base<Variant>();
  if (data == nullptr) {
    buf->Unref();
    return nullptr;
  }
  const int64_t in_n = ProtoHelper<Variant>::NumElements(in);
  if (in_n <= 0) {
    std::fill_n(data, n, Variant());
  } else {
    // If the tensor shape says we have n < in_n elements in the output tensor,
    // decode only the first n of the in_n input elements. In all other cases,
    // decode all in_n input elements and set the remaining elements up to n
    // to the default Variant() value.
    const int64_t real_n = n < in_n ? n : in_n;
    for (int64_t i = 0; i < real_n; ++i) {
      data[i] = in.variant_val(i);
      if (!DecodeUnaryVariant(&data[i])) {
        LOG(ERROR) << "Could not decode variant with type_name: \""
                   << data[i].TypeName()
                   << "\". Perhaps you forgot to register a "
                      "decoder via REGISTER_UNARY_VARIANT_DECODE_FUNCTION?";
        buf->Unref();
        return nullptr;
      }
    }
    for (int64_t i = in_n; i < n; ++i) {
      data[i] = Variant();
    }
  }
  return buf;
}

// fp16 and bfloat16 are opaque to the protobuf, so we deserialize them
// identically to uint16 but with data stored in half_val instead of int_val
// (i.e., we don't use ProtoHelper<uint16>).
template <>
TensorBuffer* FromProtoField<Eigen::half>(Allocator* a, const TensorProto& in,
                                          int64_t n) {
  CHECK_GT(n, 0);
  Buffer<Eigen::half>* buf = new Buffer<Eigen::half>(a, n);
  uint16* data = buf->template base<uint16>();
  if (data == nullptr) {
    buf->Unref();
    return nullptr;
  }
  const int64_t in_n = in.half_val().size();
  auto begin = in.half_val().begin();
  if (n <= in_n) {
    std::copy_n(begin, n, data);
  } else if (in_n > 0) {
    std::copy_n(begin, in_n, data);
    const uint16 last = *(data + in_n - 1);
    std::fill_n(data + in_n, n - in_n, last);
  } else {
    std::fill_n(data, n, 0);
  }
  return buf;
}

template <>
TensorBuffer* FromProtoField<bfloat16>(Allocator* a, const TensorProto& in,
                                       int64_t n) {
  CHECK_GT(n, 0);
  Buffer<bfloat16>* buf = new Buffer<bfloat16>(a, n);
  uint16* data = buf->template base<uint16>();
  if (data == nullptr) {
    buf->Unref();
    return nullptr;
  }
  const int64_t in_n = in.half_val().size();
  auto begin = in.half_val().begin();
  if (n <= in_n) {
    std::copy_n(begin, n, data);
  } else if (in_n > 0) {
    std::copy_n(begin, in_n, data);
    const uint16 last = *(data + in_n - 1);
    std::fill_n(data + in_n, n - in_n, last);
  } else {
    std::fill_n(data, n, 0);
  }
  return buf;
}

// Copies T[n] stored in the buffer "in" into the repeated field in
// "out" corresponding to type T.
template <typename T>
void ToProtoField(const TensorBuffer& in, int64_t n, TensorProto* out) {
  const T* data = in.base<const T>();
  // NOTE: T may not be the same as
  // ProtoHelper<T>::FieldType::value_type. E.g., T==int16,
  // ProtoHelper<T>::FieldType::value_type==int32. If performance is
  // critical, we can specialize T=float and do memcpy directly.
  ProtoHelper<T>::Fill(data, n, out);
}

void RefIfNonNull(core::RefCounted* buf) {
  if (buf) buf->Ref();
}

void UnrefIfNonNull(core::RefCounted* buf) {
  if (buf) buf->Unref();
}

}  // end namespace

Tensor::Tensor() : Tensor(DT_FLOAT) {}

Tensor::Tensor(DataType type) : shape_(type), buf_(nullptr) {}

Tensor::Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf)
    : shape_(shape), buf_(buf) {
  set_dtype(type);
  RefIfNonNull(buf);
}

Tensor::Tensor(DataType type, TensorShape shape,
               core::RefCountPtr<TensorBuffer> buf)
    : shape_(std::move(shape)), buf_(buf.release()) {
  set_dtype(type);
}

bool Tensor::IsInitialized() const {
  return (buf_ != nullptr && buf_->data() != nullptr) ||
         shape_.num_elements() == 0;
}

void Tensor::CheckType(DataType expected_dtype) const {
  CHECK_EQ(dtype(), expected_dtype)
      << " " << DataTypeString(expected_dtype) << " expected, got "
      << DataTypeString(dtype());
}

void Tensor::CheckTypeAndIsAligned(DataType expected_dtype) const {
  CHECK_EQ(dtype(), expected_dtype)
      << " " << DataTypeString(expected_dtype) << " expected, got "
      << DataTypeString(dtype());
  CHECK(IsAligned()) << "ptr = " << base<void>();
}

void Tensor::CheckIsAlignedAndSingleElement() const {
  CHECK(IsAligned()) << "Aligned and single element";
  CHECK_EQ(1, NumElements()) << "Must have a one element tensor";
}

Tensor::~Tensor() { UnrefIfNonNull(buf_); }

Status Tensor::BitcastFrom(const Tensor& other, DataType dtype,
                           const TensorShape& shape) {
  int in_size = DataTypeSize(other.dtype());
  int out_size = DataTypeSize(dtype);
  if (in_size == 0) {
    return errors::InvalidArgument("other tensor has zero-sized data type");
  }
  if (out_size == 0) {
    return errors::InvalidArgument("specified output type is zero-sized");
  }
  if (shape.num_elements() * out_size !=
      other.shape().num_elements() * in_size) {
    return errors::InvalidArgument(
        "input and output shapes/data type sizes are not compatible");
  }
  shape_ = shape;
  shape_.set_data_type(dtype);
  if (buf_ != other.buf_) {
    UnrefIfNonNull(buf_);
    buf_ = other.buf_;
    RefIfNonNull(buf_);
  }
  return OkStatus();
}
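
// Illustrative usage sketch (not part of the original file): BitcastFrom
// reinterprets another tensor's storage without copying, as long as the total
// byte sizes agree. For example, a 2x2 DT_FLOAT tensor (16 bytes) can be
// viewed as a 2x2x4 DT_UINT8 tensor:
//
//   Tensor floats(DT_FLOAT, TensorShape({2, 2}));
//   Tensor bytes;
//   Status s = bytes.BitcastFrom(floats, DT_UINT8, TensorShape({2, 2, 4}));
//   // On success, `bytes` shares (and refs) the same underlying buffer.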

// Notice that buf_ either points to a regular TensorBuffer or a SubBuffer.
// For the latter case, we have to make sure that the refcount is
// one both for the SubBuffer _and_ the underlying TensorBuffer.
bool Tensor::RefCountIsOne() const {
  return buf_ != nullptr && buf_->RefCountIsOne() &&
         buf_->root_buffer()->RefCountIsOne() && buf_->OwnsMemory();
}

// The macro CASES() expands to a switch statement conditioned on
// TYPE_ENUM. Each case expands the STMTS after a typedef for T.
#define SINGLE_ARG(...) __VA_ARGS__
#define CASE(TYPE, STMTS)                 \
  case DataTypeToEnum<TYPE>::value: {     \
    typedef TF_ATTRIBUTE_UNUSED TYPE T;   \
    STMTS;                                \
    break;                                \
  }
#define CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, INVALID, DEFAULT) \
  switch (TYPE_ENUM) {                                         \
    CASE(float, SINGLE_ARG(STMTS))                             \
    CASE(double, SINGLE_ARG(STMTS))                            \
    CASE(int32, SINGLE_ARG(STMTS))                             \
    CASE(uint8, SINGLE_ARG(STMTS))                             \
    CASE(uint16, SINGLE_ARG(STMTS))                            \
    CASE(uint32, SINGLE_ARG(STMTS))                            \
    CASE(uint64, SINGLE_ARG(STMTS))                            \
    CASE(int16, SINGLE_ARG(STMTS))                             \
    CASE(int8, SINGLE_ARG(STMTS))                              \
    CASE(tstring, SINGLE_ARG(STMTS))                           \
    CASE(complex64, SINGLE_ARG(STMTS))                         \
    CASE(complex128, SINGLE_ARG(STMTS))                        \
    CASE(int64_t, SINGLE_ARG(STMTS))                           \
    CASE(bool, SINGLE_ARG(STMTS))                              \
    CASE(qint32, SINGLE_ARG(STMTS))                            \
    CASE(quint8, SINGLE_ARG(STMTS))                            \
    CASE(qint8, SINGLE_ARG(STMTS))                             \
    CASE(quint16, SINGLE_ARG(STMTS))                           \
    CASE(qint16, SINGLE_ARG(STMTS))                            \
    CASE(bfloat16, SINGLE_ARG(STMTS))                          \
    CASE(Eigen::half, SINGLE_ARG(STMTS))                       \
    CASE(ResourceHandle, SINGLE_ARG(STMTS))                    \
    CASE(Variant, SINGLE_ARG(STMTS))                           \
    case DT_INVALID:                                           \
      INVALID;                                                 \
      break;                                                   \
    default:                                                   \
      DEFAULT;                                                 \
      break;                                                   \
  }

#define CASES(TYPE_ENUM, STMTS)                                      \
  CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, LOG(FATAL) << "Type not set"; \
                     , LOG(FATAL) << "Unexpected type: " << TYPE_ENUM;)
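
// Illustrative note (not part of the original file): a call such as
//
//   CASES(dtype(), buf_ = new Buffer<T>(a, shape.num_elements()));
//
// expands to a switch over every supported DataType, where each case defines
// a local typedef T for the matching C++ type and then runs the statement.
// For DT_FLOAT the generated case looks roughly like:
//
//   case DataTypeToEnum<float>::value: {
//     typedef TF_ATTRIBUTE_UNUSED float T;
//     buf_ = new Buffer<T>(a, shape.num_elements());  // T is float here.
//     break;
//   }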

Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape)
    : shape_(shape), buf_(nullptr) {
  set_dtype(type);
  CHECK_NOTNULL(a);
  if (shape_.num_elements() > 0 || a->AllocatesOpaqueHandle()) {
    CASES(type, buf_ = new Buffer<T>(a, shape.num_elements()));
  }
  if (MemoryLoggingEnabled() && buf_ != nullptr && buf_->data() != nullptr) {
    LogMemory::RecordTensorAllocation("Unknown", LogMemory::UNKNOWN_STEP_ID,
                                      *this);
  }
}

Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape,
               const AllocationAttributes& allocation_attr)
    : shape_(shape), buf_(nullptr) {
  set_dtype(type);
  CHECK_NOTNULL(a);
  if (shape_.num_elements() > 0 || a->AllocatesOpaqueHandle()) {
    CASES(type, buf_ = new Buffer<T>(a, shape.num_elements(), allocation_attr));
  }
  if (MemoryLoggingEnabled() && !allocation_attr.allocation_will_be_logged &&
      buf_ != nullptr && buf_->data() != nullptr) {
    LogMemory::RecordTensorAllocation("Unknown (with attributes)",
                                      LogMemory::UNKNOWN_STEP_ID, *this);
  }
}

Status Tensor::BuildTensor(DataType type, const TensorShape& shape,
                           Tensor* out_tensor) {
  // Avoid crashes due to invalid or unsupported types.
  CASES_WITH_DEFAULT(
      type, {}, return errors::InvalidArgument("Type not set"),
      return errors::InvalidArgument("Unexpected type: ", DataType_Name(type)));
  *out_tensor = Tensor(type, shape);
  return OkStatus();
}

// NOTE(mrry): The default allocator for a Tensor (when none is specified) is
// the default CPU allocator for NUMA zone 0. Accessing that currently involves
// acquiring a lock, which guards initialization of the per-NUMA zone
// allocators, and becomes highly contended.
//
// Note also that it would be better if all Tensor allocations required the user
// to specify an allocator, for purposes of accounting, etc. However, the
// default allocator is widely used throughout the codebase and in client code.
static Allocator* get_default_cpu_allocator() {
  static Allocator* default_cpu_allocator =
      cpu_allocator(tsl::port::kNUMANoAffinity);
  return default_cpu_allocator;
}

Tensor::Tensor(DataType type, const TensorShape& shape)
    : Tensor(get_default_cpu_allocator(), type, shape) {}

bool Tensor::HostScalarTensorBufferBase::GetAllocatedBytes(
    size_t* out_bytes) const {
  // `this->FillAllocationDescription()` never sets allocated bytes information,
  // so we can short-circuit the construction of an `AllocationDescription`.
  return false;
}

void Tensor::HostScalarTensorBufferBase::FillAllocationDescription(
    AllocationDescription* proto) const {
  proto->set_requested_bytes(size());
  proto->set_allocator_name("HostScalarTensorBuffer");
  proto->set_ptr(reinterpret_cast<uintptr_t>(data()));
}

template <typename T>
class SubBuffer : public TensorBuffer {
 public:
  // This buffer is an alias to buf[delta, delta + n).
  SubBuffer(TensorBuffer* buf, int64_t delta, int64_t n)
      : TensorBuffer(buf->base<T>() + delta),
        root_(buf->root_buffer()),
        elem_(n) {
    // Sanity check. The caller should ensure the sub buffer is valid.
    CHECK_LE(root_->base<T>(), this->base<T>());
    T* root_limit = root_->base<T>() + root_->size() / sizeof(T);
    CHECK_LE(this->base<T>(), root_limit);
    CHECK_LE(this->base<T>() + n, root_limit);
    // Hold a ref of the underlying root buffer.
    // NOTE: 'buf' is a sub-buffer inside the 'root_' buffer.
    root_->Ref();
  }

  size_t size() const override { return sizeof(T) * elem_; }
  TensorBuffer* root_buffer() override { return root_; }
  bool GetAllocatedBytes(size_t* out_bytes) const override {
    return root_->GetAllocatedBytes(out_bytes);
  }
  void FillAllocationDescription(AllocationDescription* proto) const override {
    root_->FillAllocationDescription(proto);
  }

 private:
  TensorBuffer* root_;
  int64_t elem_;

  ~SubBuffer() override { root_->Unref(); }

  TF_DISALLOW_COPY_AND_ASSIGN(SubBuffer);
};

Tensor Tensor::Slice(int64_t start, int64_t limit) const {
  CHECK_GE(dims(), 1);
  CHECK_LE(0, start);
  CHECK_LE(start, limit);
  int64_t dim0_size = shape_.dim_size(0);
  CHECK_LE(limit, dim0_size);
  if ((start == 0) && (limit == dim0_size)) {
    return *this;
  }
  Tensor ret;
  ret.shape_ = shape_;
  ret.set_dtype(dtype());
  ret.buf_ = nullptr;
  if (dim0_size > 0) {
    const int64_t elems_per_dim0 = NumElements() / dim0_size;
    const int64_t delta = start * elems_per_dim0;
    dim0_size = limit - start;
    ret.shape_.set_dim(0, dim0_size);
    const int64_t num_elems = dim0_size * elems_per_dim0;
    if (buf_) {
      DataType dt = dtype();
      CASES(dt, ret.buf_ = new SubBuffer<T>(buf_, delta, num_elems));
    }
  }
  return ret;
}
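
// Illustrative usage sketch (not part of the original file): Slice() returns
// a view over a contiguous range of the first dimension; no data is copied,
// and the result aliases (a SubBuffer of) the same underlying storage:
//
//   Tensor t(DT_FLOAT, TensorShape({10, 4}));
//   Tensor rows = t.Slice(2, 5);  // shape {3, 4}, aliases rows 2..4 of t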

Tensor Tensor::SubSlice(int64_t index) const {
  CHECK_GE(dims(), 1);  // Crash ok.
  CHECK_LE(0, index);   // Crash ok.
  int64_t dim0_size = shape_.dim_size(0);
  CHECK_LE(index, dim0_size);  // Crash ok.
  Tensor ret;
  ret.shape_ = shape_;
  ret.shape_.RemoveDim(0);
  ret.set_dtype(dtype());
  ret.buf_ = nullptr;
  if (dim0_size > 0) {
    const int64_t elems_per_dim0 = NumElements() / dim0_size;
    const int64_t delta = index * elems_per_dim0;
    const int64_t num_elems = elems_per_dim0;
    if (buf_) {
      DataType dt = dtype();
      CASES(dt, ret.buf_ = new SubBuffer<T>(buf_, delta, num_elems));
    }
  }
  return ret;
}
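
// Illustrative usage sketch (not part of the original file): SubSlice()
// selects a single index along the first dimension and drops that dimension,
// again aliasing the original buffer rather than copying:
//
//   Tensor t(DT_FLOAT, TensorShape({10, 4}));
//   Tensor row = t.SubSlice(3);  // shape {4}, aliases row 3 of t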

bool Tensor::FromProto(const TensorProto& proto) {
  return FromProto(get_default_cpu_allocator(), proto);
}

bool Tensor::FromProto(Allocator* a, const TensorProto& proto) {
  CHECK_NOTNULL(a);
  TensorBuffer* p = nullptr;
  if (!TensorShape::IsValid(proto.tensor_shape())) return false;
  if (proto.dtype() == DT_INVALID) return false;
  TensorShape shape(proto.tensor_shape());
  const int64_t N = shape.num_elements();
  if (N > 0 && proto.dtype()) {
    bool dtype_error = false;
    if (!proto.tensor_content().empty()) {
      const auto& content = proto.tensor_content();
      CASES_WITH_DEFAULT(proto.dtype(), p = Helper<T>::Decode(a, content, N),
                         dtype_error = true, dtype_error = true);
    } else {
      CASES_WITH_DEFAULT(proto.dtype(), p = FromProtoField<T>(a, proto, N),
                         dtype_error = true, dtype_error = true);
    }
    if (dtype_error || p == nullptr) return false;
  } else {
    // Handle the case of empty tensors (N = 0) or tensors with incomplete shape
    // (N = -1). All other values of `shape.num_elements()` should be invalid by
    // construction.
    // Here, we just need to validate that the `proto.dtype()` value is valid.
    bool dtype_error = false;
    CASES_WITH_DEFAULT(proto.dtype(), break, dtype_error = true,
                       dtype_error = true);
    if (dtype_error) return false;
  }
  shape_ = shape;
  set_dtype(proto.dtype());
  UnrefIfNonNull(buf_);
  buf_ = p;
  // TODO(misard) add tracking of which kernels and steps are calling
  // FromProto.
  if (MemoryLoggingEnabled() && buf_ != nullptr && buf_->data() != nullptr) {
    LogMemory::RecordTensorAllocation("Unknown (from Proto)",
                                      LogMemory::UNKNOWN_STEP_ID, *this);
  }
  return true;
}

void Tensor::AsProtoField(TensorProto* proto) const {
  proto->Clear();
  shape_.AsProto(proto->mutable_tensor_shape());
  proto->set_dtype(dtype());
  if (buf_) {
    CASES(dtype(), ToProtoField<T>(*buf_, shape_.num_elements(), proto));
  }
}

void Tensor::AsProtoTensorContent(TensorProto* proto) const {
  proto->Clear();
  proto->set_dtype(dtype());
  shape_.AsProto(proto->mutable_tensor_shape());
  if (buf_) {
    CASES(dtype(), Helper<T>::Encode(buf_, shape_.num_elements(),
                                     proto->mutable_tensor_content()));
  }
}
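
// Illustrative round-trip sketch (not part of the original file):
// AsProtoTensorContent() serializes the buffer into the compact
// tensor_content bytes, while AsProtoField() uses the typed *_val fields;
// either form can be restored with FromProto():
//
//   Tensor t(DT_INT32, TensorShape({3}));
//   t.flat<int32>().setZero();
//   TensorProto proto;
//   t.AsProtoTensorContent(&proto);  // or t.AsProtoField(&proto);
//   Tensor restored;
//   CHECK(restored.FromProto(proto));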

size_t Tensor::TotalBytes() const {
  if (shape_.num_elements() == 0) return 0;
  CHECK(buf_) << "null buf_ with non-zero shape size " << shape_.num_elements();
  CASES(dtype(), return Helper<T>::TotalBytes(buf_, shape_.num_elements()));
  return 0;  // Makes compiler happy.
}

size_t Tensor::AllocatedBytes() const {
  if (buf_) {
    size_t ret;
    if (buf_->GetAllocatedBytes(&ret)) {
      return ret;
    }
  }
  return TotalBytes();
}

bool Tensor::CanUseDMA() const {
  CASES(dtype(), return is_simple_type<T>::value);
  return false;  // Makes compiler happy.
}

#undef CASES
#undef CASE

namespace {

// StrCat and StrAppend don't support Eigen::half directly at the moment, and
// we would like to keep them compatible with their absl counterparts, for ease
// of migration. We could rely on errors::internal::PrepareForStrCat() but the
// logic is so simple we can just replicate it here, where it is close to its
// usage and easy to change later. And there's the extra benefit of not
// accessing an 'internal' namespace.
inline const strings::AlphaNum& PrintOneElement(const strings::AlphaNum& a,
                                                bool print_v2) {
  return a;
}
inline string PrintOneElement(const tstring& a, bool print_v2) {
  if (print_v2) {
    return "\"" + absl::Utf8SafeCEscape(a) + "\"";
  } else {
    return absl::Utf8SafeCEscape(a);
  }
}
inline float PrintOneElement(const Eigen::half& h, bool print_v2) {
  return static_cast<float>(h);
}

inline float PrintOneElement(bfloat16 f, bool print_v2) {
  return static_cast<float>(f);
}

// Print from left dim to right dim recursively.
template <typename T>
void PrintOneDim(int dim_index, const gtl::InlinedVector<int64, 4>& shape,
                 int64_t limit, int shape_size, const T* data,
                 int64_t* data_index, string* result) {
  if (*data_index >= limit) return;
  int64_t element_count = shape[dim_index];
  // We have reached the right-most dimension of the tensor.
  if (dim_index == shape_size - 1) {
    for (int64_t i = 0; i < element_count; i++) {
      if (*data_index >= limit) {
        // If not enough elements have been printed, append "...".
        if (dim_index != 0) {
          strings::StrAppend(result, "...");
        }
        return;
      }
      if (i > 0) strings::StrAppend(result, " ");
      strings::StrAppend(result, PrintOneElement(data[(*data_index)++], false));
    }
    return;
  }
  // Loop every element of one dim.
  for (int64_t i = 0; i < element_count; i++) {
    bool flag = false;
    if (*data_index < limit) {
      strings::StrAppend(result, "[");
      flag = true;
    }
    // As for each element, print the sub-dim.
    PrintOneDim(dim_index + 1, shape, limit, shape_size, data, data_index,
                result);
    if (*data_index < limit || flag) {
      strings::StrAppend(result, "]");
      flag = false;
    }
  }
}

// Appends the spacing between elements for a given dim onto a result string
void PrintDimSpacing(int dim_index, int num_dims, string* result) {
  if (dim_index == num_dims - 1) {
    strings::StrAppend(result, " ");
    return;
  }
  for (int j = 0; j < num_dims - dim_index - 1; j++) {
    strings::StrAppend(result, "\n");
  }
  for (int j = 0; j <= dim_index; j++) {
    strings::StrAppend(result, " ");
  }
}

// Print from left dim to right dim recursively.
template <typename T>
void PrintOneDimV2(int dim_index, const gtl::InlinedVector<int64, 4>& shape,
                   int64_t num_elts_at_ends, int num_dims, const T* data,
                   int64_t data_index, string* result) {
  // We have recursed beyond all the dimensions into a single element
  // of the tensor.
  if (dim_index == num_dims) {
    strings::StrAppend(result, PrintOneElement(data[data_index], true));
    return;
  }

  strings::StrAppend(result, "[");
  int64_t element_count = shape[dim_index];
  int64_t start_of_end =
      std::max(num_elts_at_ends, element_count - num_elts_at_ends);

  // Loop every element of one dim.
  int64_t elements_per_iter = 1;
  for (int i = dim_index + 1; i < num_dims; i++) {
    elements_per_iter *= shape[i];
  }
  for (int64_t i = 0; (i < num_elts_at_ends) && (i < element_count); i++) {
    if (i > 0) {
      PrintDimSpacing(dim_index, num_dims, result);
    }

    // As for each element, print the sub-dim.
    PrintOneDimV2(dim_index + 1, shape, num_elts_at_ends, num_dims, data,
                  data_index + elements_per_iter * i, result);
  }
  if (element_count > 2 * num_elts_at_ends) {
    PrintDimSpacing(dim_index, num_dims, result);
    strings::StrAppend(result, "...");
  }
  for (int64_t i = start_of_end; i < element_count; i++) {
    // As for each element, print the sub-dim.
    PrintDimSpacing(dim_index, num_dims, result);
    PrintOneDimV2(dim_index + 1, shape, num_elts_at_ends, num_dims, data,
                  data_index + elements_per_iter * i, result);
  }

  strings::StrAppend(result, "]");
}

template <typename T>
string SummarizeArray(int64_t limit, int64_t num_elts,
                      const TensorShape& tensor_shape, const char* data,
                      const bool print_v2) {
  string ret;
  const T* array = reinterpret_cast<const T*>(data);

  const gtl::InlinedVector<int64_t, 4> shape = tensor_shape.dim_sizes();
  if (shape.empty()) {
    for (int64_t i = 0; i < limit; ++i) {
      if (i > 0) strings::StrAppend(&ret, " ");
      strings::StrAppend(&ret, PrintOneElement(array[i], print_v2));
    }
    if (num_elts > limit) strings::StrAppend(&ret, "...");
    return ret;
  }
  if (print_v2) {
    const int num_dims = tensor_shape.dims();
    PrintOneDimV2(0, shape, limit, num_dims, array, 0, &ret);
  } else {
    int64_t data_index = 0;
    const int shape_size = tensor_shape.dims();
    PrintOneDim(0, shape, limit, shape_size, array, &data_index, &ret);

    if (num_elts > limit) strings::StrAppend(&ret, "...");
  }

  return ret;
}
}  // namespace

string Tensor::SummarizeValue(int64_t max_entries, bool print_v2) const {
  const int64_t num_elts = NumElements();
  if (max_entries < 0) {
    max_entries = num_elts;
  }
  size_t limit = std::min(max_entries, num_elts);
  if ((limit > 0) && (buf_ == nullptr)) {
    return strings::StrCat("uninitialized Tensor of ", num_elts,
                           " elements of type ", dtype());
  }
  const char* data = limit > 0 ? tensor_data().data() : nullptr;
  switch (dtype()) {
    case DT_BFLOAT16:
      return SummarizeArray<bfloat16>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_HALF:
      return SummarizeArray<Eigen::half>(limit, num_elts, shape_, data,
                                         print_v2);
      break;
    case DT_FLOAT:
      return SummarizeArray<float>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_DOUBLE:
      return SummarizeArray<double>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_UINT32:
      return SummarizeArray<uint32>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_INT32:
      return SummarizeArray<int32>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_UINT8:
    case DT_QUINT8:
      return SummarizeArray<uint8>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_UINT16:
    case DT_QUINT16:
      return SummarizeArray<uint16>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_INT16:
    case DT_QINT16:
      return SummarizeArray<int16>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_INT8:
    case DT_QINT8:
      return SummarizeArray<int8>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_UINT64:
      return SummarizeArray<uint64>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_INT64:
      return SummarizeArray<int64_t>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_BOOL:
      // TODO(tucker): Is it better to emit "True False..."? This
      // will emit "1 0..." which is more compact.
      return SummarizeArray<bool>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_STRING:
      return SummarizeArray<tstring>(limit, num_elts, shape_, data, print_v2);
      break;
    default: {
      // All irregular cases
      string ret;
      if (print_v2 && (dims() > 0)) {
        strings::StrAppend(&ret, "[");
      }
      // TODO(irving): Don't call flat every time around this
      // loop.
      for (size_t i = 0; i < limit; ++i) {
        if (i > 0) strings::StrAppend(&ret, " ");
        switch (dtype()) {
          case DT_VARIANT: {
            const Variant& v = flat<Variant>()(i);
            strings::StrAppend(&ret, "<", v.SummarizeValue(), ">");
          } break;
          case DT_RESOURCE: {
            const ResourceHandle& r = flat<ResourceHandle>()(i);
            strings::StrAppend(&ret, "<", r.SummarizeValue(), ">");
          } break;
          default:
            // TODO(zhifengc, josh11b): Pretty-print other types (bool,
            // complex64, quantized).
            strings::StrAppend(&ret, "?");
        }
      }
      if (max_entries < num_elts) strings::StrAppend(&ret, "...");
      if (print_v2 && (dims() > 0)) {
        strings::StrAppend(&ret, "]");
      }
      return ret;
    }
  }
}
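
// Illustrative example (not part of the original file; exact spacing depends
// on the printing helpers above): for a DT_INT32 tensor of shape {10} holding
// the values 0..9, SummarizeValue(3) produces roughly "0 1 2...", while
// SummarizeValue(3, /*print_v2=*/true) brackets dimensions and shows both
// ends, roughly "[0 1 2 ... 7 8 9]".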

StringPiece Tensor::tensor_data() const {
  if (buf_ == nullptr) return StringPiece();  // Don't die for empty tensors
  return StringPiece(static_cast<char*>(buf_->data()), TotalBytes());
}

void* Tensor::data() const {
  if (buf_ == nullptr) return nullptr;  // Don't die for empty tensors
  return static_cast<void*>(buf_->data());
}

bool Tensor::SharesBufferWith(const Tensor& b) const {
  return buf_ != nullptr && b.buf_ != nullptr &&
         buf_->root_buffer() == b.buf_->root_buffer();
}

string Tensor::DebugString(int num_values) const {
  return strings::StrCat("Tensor<type: ", DataTypeString(dtype()),
                         " shape: ", shape().DebugString(),
                         " values: ", SummarizeValue(num_values), ">");
}

string Tensor::DeviceSafeDebugString() const {
  return strings::StrCat("Tensor<type: ", DataTypeString(dtype()),
                         " shape: ", shape().DebugString(), ">");
}

void Tensor::FillDescription(TensorDescription* description) const {
  description->set_dtype(dtype());
  shape().AsProto(description->mutable_shape());
  if (buf_ != nullptr && buf_->data() != nullptr) {
    buf_->FillAllocationDescription(
        description->mutable_allocation_description());
  }
}

gtl::InlinedVector<int64_t, 4> Tensor::ComputeFlatInnerDims(
    gtl::ArraySlice<int64_t> orig, int64_t num_out_dims) {
  gtl::InlinedVector<int64_t, 4> out_dims(num_out_dims, 0);
  int64_t offset = orig.size() - num_out_dims;
  for (int64_t out_dim = num_out_dims - 1; out_dim >= 0; --out_dim) {
    const int64_t in_dim = out_dim + offset;
    out_dims[out_dim] = in_dim < 0 ? 1 : orig[in_dim];
  }
  for (int64_t in_dim = 0; in_dim < offset; ++in_dim) {
    out_dims[0] *= orig[in_dim];
  }
  return out_dims;
}

gtl::InlinedVector<int64_t, 4> Tensor::ComputeFlatOuterDims(
    gtl::ArraySlice<int64_t> orig, int64_t num_out_dims) {
  gtl::InlinedVector<int64_t, 4> out_dims(num_out_dims, 0);
  for (int64_t out_dim = 0; out_dim <= num_out_dims - 1; ++out_dim) {
    out_dims[out_dim] = out_dim >= orig.size() ? 1 : orig[out_dim];
  }
  for (int64_t in_dim = num_out_dims; in_dim < orig.size(); ++in_dim) {
    out_dims[num_out_dims - 1] *= orig[in_dim];
  }
  return out_dims;
}
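
// Illustrative example (not part of the original file): for an input shape
// {2, 3, 4, 5} and num_out_dims = 2,
//   ComputeFlatInnerDims -> {2 * 3 * 4, 5} = {24, 5}  (leading dims folded)
//   ComputeFlatOuterDims -> {2, 3 * 4 * 5} = {2, 60}  (trailing dims folded)
// These helpers back the flat_inner_dims<T>() and flat_outer_dims<T>()
// accessors declared in tensor.h.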

}  // namespace tensorflow