1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // Implementation notes: |
17 | // |
18 | // Tensor.cc uses a few templated classes and structs to facilitate |
19 | // implementation of the Tensor class. |
20 | // |
21 | // * Buffer<T>: provides the implementation for a typed array T[n]. |
22 | // The array is allocated by the given allocator. It runs T's |
23 | // default constructors and destructors when T is not a simple type |
24 | // (e.g., string.), and skips them otherwise. |
25 | // |
// * Helper<T>: provides various routines given type T. The routines
//   include running the constructor and destructor of T[], encoding
//   and decoding T[] into/from a Cord, etc.
29 | |
30 | #include "tensorflow/core/framework/tensor.h" |
31 | |
32 | #include <utility> |
33 | |
34 | #include "absl/strings/escaping.h" |
35 | #include "tensorflow/core/framework/allocation_description.pb.h" |
36 | #include "tensorflow/core/framework/log_memory.h" |
37 | #include "tensorflow/core/framework/resource_handle.h" |
38 | #include "tensorflow/core/framework/resource_handle.pb.h" |
39 | #include "tensorflow/core/framework/tensor.pb.h" |
40 | #include "tensorflow/core/framework/tensor_description.pb.h" |
41 | #include "tensorflow/core/framework/type_traits.h" |
42 | #include "tensorflow/core/framework/typed_allocator.h" |
43 | #include "tensorflow/core/framework/types.h" |
44 | #include "tensorflow/core/framework/types.pb.h" |
45 | #include "tensorflow/core/framework/variant.h" |
46 | #include "tensorflow/core/framework/variant_encode_decode.h" |
47 | #include "tensorflow/core/framework/variant_op_registry.h" |
48 | #include "tensorflow/core/framework/variant_tensor_data.h" |
49 | #include "tensorflow/core/lib/core/coding.h" |
50 | #include "tensorflow/core/lib/core/errors.h" |
51 | #include "tensorflow/core/lib/core/status.h" |
52 | #include "tensorflow/core/lib/gtl/inlined_vector.h" |
53 | #include "tensorflow/core/lib/strings/str_util.h" |
54 | #include "tensorflow/core/lib/strings/strcat.h" |
55 | #include "tensorflow/core/platform/errors.h" |
56 | #include "tensorflow/core/platform/logging.h" |
57 | #include "tensorflow/core/platform/macros.h" |
58 | #include "tensorflow/core/platform/protobuf.h" |
59 | #include "tensorflow/core/platform/tensor_coding.h" |
60 | #include "tensorflow/core/platform/types.h" |
61 | |
62 | namespace tensorflow { |
63 | |
// Allow Tensors to be stored inside Variants with automatic
// encoding/decoding when those Variants are themselves being decoded
// in a Tensor's FromProto. The registration is keyed by the type name
// string "tensorflow::Tensor".
//
// NOTE(mrry): The corresponding "copy function" registrations can be found in
// ../common_runtime/copy_tensor.cc (due to dependencies on other common_runtime
// code).
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(Tensor, "tensorflow::Tensor" );
72 | |
73 | bool TensorBuffer::GetAllocatedBytes(size_t* out_bytes) const { |
74 | AllocationDescription allocation_description; |
75 | FillAllocationDescription(&allocation_description); |
76 | if (allocation_description.allocated_bytes() > 0) { |
77 | *out_bytes = allocation_description.allocated_bytes(); |
78 | return true; |
79 | } else { |
80 | return false; |
81 | } |
82 | } |
83 | |
84 | namespace { |
85 | |
86 | // An un-templated base class for Buffer. |
87 | class BufferBase : public TensorBuffer { |
88 | public: |
89 | explicit BufferBase(Allocator* alloc, void* data_ptr) |
90 | : TensorBuffer(data_ptr), alloc_(alloc) {} |
91 | |
92 | TensorBuffer* root_buffer() override { return this; } |
93 | |
94 | bool GetAllocatedBytes(size_t* out_bytes) const override { |
95 | if (alloc_->TracksAllocationSizes()) { |
96 | *out_bytes = alloc_->AllocatedSize(data()); |
97 | return *out_bytes > 0; |
98 | } else { |
99 | return false; |
100 | } |
101 | } |
102 | |
103 | void FillAllocationDescription(AllocationDescription* proto) const override { |
104 | void* data_ptr = data(); |
105 | int64_t rb = size(); |
106 | proto->set_requested_bytes(rb); |
107 | proto->set_allocator_name(alloc_->Name()); |
108 | proto->set_ptr(reinterpret_cast<uintptr_t>(data_ptr)); |
109 | if (alloc_->TracksAllocationSizes()) { |
110 | int64_t ab = alloc_->AllocatedSize(data_ptr); |
111 | proto->set_allocated_bytes(ab); |
112 | int64_t id = alloc_->AllocationId(data_ptr); |
113 | if (id > 0) { |
114 | proto->set_allocation_id(id); |
115 | } |
116 | if (RefCountIsOne()) { |
117 | proto->set_has_single_reference(true); |
118 | } |
119 | } |
120 | } |
121 | |
122 | // Returns the type of the underlying memory. |
123 | AllocatorMemoryType GetMemoryType() const override { |
124 | return alloc_->GetMemoryType(); |
125 | } |
126 | |
127 | protected: |
128 | void RecordDeallocation() { |
129 | LogMemory::RecordTensorDeallocation(alloc_->AllocationId(data()), |
130 | alloc_->Name()); |
131 | } |
132 | |
133 | Allocator* const alloc_; |
134 | }; |
135 | |
// Typed ref-counted buffer: T[n].
template <typename T>
class Buffer : public BufferBase {
 public:
  // Allocates room for "n" elements of type T from allocator "a" using
  // default allocation attributes.
  Buffer(Allocator* a, int64_t n);
  // Same, but with caller-provided allocation attributes.
  Buffer(Allocator* a, int64_t n, const AllocationAttributes& allocation_attr);

  // Total size of the storage in bytes.
  size_t size() const override { return sizeof(T) * elem_; }

 private:
  // Number of elements of type T held by this buffer.
  int64_t elem_;

  // Private: the buffer is ref-counted and must be released via Unref().
  ~Buffer() override;

  TF_DISALLOW_COPY_AND_ASSIGN(Buffer);
};
152 | |
// Logs a mismatch between the serialized input size and the size implied by
// the tensor's element count; callers (e.g. Helper<T>::Decode) treat such a
// mismatch as a decode failure.
void LogUnexpectedSize(int64_t actual, int64_t expected) {
  LOG(ERROR) << "Input size was " << actual << " and expected " << expected;
}
156 | |
157 | bool MemoryLoggingEnabled() { |
158 | static bool memory_logging_enabled = LogMemory::IsEnabled(); |
159 | return memory_logging_enabled; |
160 | } |
161 | |
// A set of helper functions depending on T.
template <typename T>
struct Helper {
  // By default, we assume T is a simple type (float, int32, etc.)
  static_assert(is_simple_type<T>::value, "T is not a simple type." );
  typedef protobuf::RepeatedField<T> RepeatedFieldType;

  // Encoder of simple type T to a string. We do a copy.
  template <typename Destination>
  static void Encode(TensorBuffer* in, int64_t n, Destination* out) {
    DCHECK_EQ(in->size(), sizeof(T) * n);
    // NOTE(review): the name AssignRefCounted suggests "out" may share
    // (ref-count) the buffer's bytes rather than copy them, contrary to
    // the comment above -- confirm against port::AssignRefCounted.
    port::AssignRefCounted(StringPiece(in->base<const char>(), in->size()), in,
                           out);
  }

  // Decoder of simple type T. Copy the bytes from "in" into the
  // tensor buffer. Returns nullptr (after logging) if "in" does not hold
  // exactly n elements' worth of bytes, or if allocation fails.
  template <typename Source>
  static TensorBuffer* Decode(Allocator* a, const Source& in, int64_t n) {
    if (in.size() != sizeof(T) * n) {
      LogUnexpectedSize(in.size(), sizeof(T) * n);
      return nullptr;
    }
    Buffer<T>* buf = new Buffer<T>(a, n);
    char* data = buf->template base<char>();
    if (data == nullptr) {
      // Allocation failed; drop our reference to the buffer.
      buf->Unref();
      return nullptr;
    }
    port::CopyToArray(in, data);
    return buf;
  }

  // Memory usage: for simple types this is exactly the buffer size.
  static int64_t TotalBytes(TensorBuffer* in, int64_t n) {
    DCHECK_EQ(in->size(), sizeof(T) * n);
    return in->size();
  }
};
201 | |
202 | // Helper specialization for string (the only non-simple type we |
203 | // support). |
204 | template <> |
205 | struct Helper<tstring> { |
206 | // Proto message uses RepeatedFieldType to hold repeated T. |
207 | typedef protobuf::RepeatedPtrField<string> RepeatedFieldType; |
208 | |
209 | // Encodes "n" elements of type string stored in "in" into Cord |
210 | // "out", which is usually the TensorProto::tensor_content. |
211 | template <typename Destination> |
212 | static void Encode(TensorBuffer* in, int64_t n, Destination* out) { |
213 | port::EncodeStringList(in->base<const tstring>(), n, out); |
214 | } |
215 | |
216 | // Decodes "n" elements of type string from "in" and constructs a |
217 | // buffer out of it. Returns nullptr if the decoding fails. "in" is |
218 | // usually the TensorProto::tensor_content. |
219 | template <typename Source> |
220 | static TensorBuffer* Decode(Allocator* a, const Source& in, int64_t n) { |
221 | Buffer<tstring>* buf = new Buffer<tstring>(a, n); |
222 | tstring* strings = buf->template base<tstring>(); |
223 | if (strings == nullptr || !port::DecodeStringList(in, strings, n)) { |
224 | buf->Unref(); |
225 | return nullptr; |
226 | } |
227 | return buf; |
228 | } |
229 | |
230 | // Returns the estimated memory usage of "n" elements of type T |
231 | // stored in buffer "in". |
232 | static int64_t TotalBytes(TensorBuffer* in, int n) { |
233 | int64_t tot = in->size(); |
234 | DCHECK_EQ(tot, sizeof(tstring) * n); |
235 | const tstring* p = in->base<const tstring>(); |
236 | for (int i = 0; i < n; ++i, ++p) tot += p->size(); |
237 | return tot; |
238 | } |
239 | }; |
240 | |
241 | template <> |
242 | struct Helper<ResourceHandle> { |
243 | // Proto message uses RepeatedFieldType to hold repeated T. |
244 | typedef protobuf::RepeatedPtrField<string> RepeatedFieldType; |
245 | |
246 | // Encodes "n" elements of type ResourceHandle stored in "in" into destination |
247 | // "out", which is usually the TensorProto::tensor_content. |
248 | template <typename Destination> |
249 | static void Encode(TensorBuffer* in, int64_t n, Destination* out) { |
250 | EncodeResourceHandleList(in->base<const ResourceHandle>(), n, |
251 | port::NewStringListEncoder(out)); |
252 | } |
253 | |
254 | // Decodes "n" elements of type string from "in" and constructs a |
255 | // buffer out of it. Returns nullptr if the decoding fails. "in" is |
256 | // usually the TensorProto::tensor_content. |
257 | template <typename Source> |
258 | static TensorBuffer* Decode(Allocator* a, const Source& in, int64_t n) { |
259 | auto* buf = new Buffer<ResourceHandle>(a, n); |
260 | ResourceHandle* ps = buf->template base<ResourceHandle>(); |
261 | if (ps == nullptr || |
262 | !DecodeResourceHandleList(port::NewStringListDecoder(in), ps, n)) { |
263 | buf->Unref(); |
264 | return nullptr; |
265 | } |
266 | return buf; |
267 | } |
268 | |
269 | // Returns the estimated memory usage of "n" elements of type T |
270 | // stored in buffer "in". |
271 | static int64_t TotalBytes(TensorBuffer* in, int n) { |
272 | return n * sizeof(ResourceHandle); |
273 | } |
274 | }; |
275 | |
276 | template <> |
277 | struct Helper<Variant> { |
278 | // Encodes "n" elements of type Variant stored in "in" into destination |
279 | // "out", which is usually the TensorProto::tensor_content. |
280 | template <typename Destination> |
281 | static void Encode(TensorBuffer* in, int64_t n, Destination* out) { |
282 | EncodeVariantList(in->base<const Variant>(), n, |
283 | port::NewStringListEncoder(out)); |
284 | } |
285 | |
286 | // Decodes "n" elements of type Variant from "in" and constructs a |
287 | // buffer out of it. Returns nullptr if the decoding fails. "in" is |
288 | // usually the TensorProto::tensor_content. |
289 | template <typename Source> |
290 | static TensorBuffer* Decode(Allocator* a, const Source& in, int64_t n) { |
291 | auto* buf = new Buffer<Variant>(a, n); |
292 | Variant* ps = buf->template base<Variant>(); |
293 | if (ps == nullptr || |
294 | !DecodeVariantList(port::NewStringListDecoder(in), ps, n)) { |
295 | buf->Unref(); |
296 | return nullptr; |
297 | } |
298 | return buf; |
299 | } |
300 | |
301 | // Returns the estimated memory usage of "n" elements of type T |
302 | // stored in buffer "in". |
303 | static int64_t TotalBytes(TensorBuffer* in, int n) { |
304 | return n * sizeof(Variant); |
305 | } |
306 | }; |
307 | |
// Traits mapping a C++ element type T to its typed repeated field in
// TensorProto. The unspecialized version is empty; every supported dtype
// gets a specialization, either via PROTO_TRAITS below or explicitly
// further down (int64, uint64, complex, quantized, half, etc.).
template <typename T>
struct ProtoHelper {};

// For a C++ type "T" (float, double, int32, etc.), the repeated field
// "N"_val (float_val, int_val, label_val, etc.) of type "F" (float,
// int32, string, etc) in the TensorProto is used for serializing the
// tensor of type "T".
#define PROTO_TRAITS(T, F, N)                                          \
  template <>                                                          \
  struct ProtoHelper<T> {                                              \
    typedef Helper<F>::RepeatedFieldType FieldType;                    \
    static FieldType::const_iterator Begin(const TensorProto& proto) { \
      return proto.N##_val().begin();                                  \
    }                                                                  \
    static size_t NumElements(const TensorProto& proto) {              \
      return proto.N##_val().size();                                   \
    }                                                                  \
    static void Fill(const T* data, size_t n, TensorProto* proto) {    \
      typename ProtoHelper<T>::FieldType copy(data, data + n);         \
      proto->mutable_##N##_val()->Swap(&copy);                         \
    }                                                                  \
  };
// Note: the narrow integer and quantized types below are all carried in the
// widened int32 "int_val" field.
PROTO_TRAITS(float, float, float);
PROTO_TRAITS(double, double, double);
PROTO_TRAITS(int32, int32, int);
PROTO_TRAITS(uint8, int32, int);
PROTO_TRAITS(uint16, int32, int);
PROTO_TRAITS(uint32, uint32, uint32);
PROTO_TRAITS(int16, int32, int);
PROTO_TRAITS(int8, int32, int);
PROTO_TRAITS(bool, bool, bool);
PROTO_TRAITS(tstring, tstring, string);
PROTO_TRAITS(qint8, int32, int);
PROTO_TRAITS(quint8, int32, int);
PROTO_TRAITS(qint16, int32, int);
PROTO_TRAITS(quint16, int32, int);
#undef PROTO_TRAITS
345 | |
// int64 tensors serialize to/from the int64_val field.
template <>
struct ProtoHelper<int64_t> {
  static protobuf::RepeatedField<int64_t>::const_iterator Begin(
      const TensorProto& proto) {
    return proto.int64_val().begin();
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.int64_val().size();
  }
  static void Fill(const int64_t* data, size_t n, TensorProto* proto) {
    // protobuf_int64 is protobuf's own 64-bit integer type, which may be a
    // distinct typedef from int64_t on some platforms.
    protobuf::RepeatedField<protobuf_int64> copy(data, data + n);
    proto->mutable_int64_val()->Swap(&copy);
  }
};
360 | |
// uint64 tensors serialize to/from the uint64_val field.
template <>
struct ProtoHelper<uint64> {
  static protobuf::RepeatedField<uint64_t>::const_iterator Begin(
      const TensorProto& proto) {
    return proto.uint64_val().begin();
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.uint64_val().size();
  }
  static void Fill(const uint64* data, size_t n, TensorProto* proto) {
    // protobuf_uint64 is protobuf's own 64-bit unsigned type, which may be a
    // distinct typedef from uint64 on some platforms.
    protobuf::RepeatedField<protobuf_uint64> copy(data, data + n);
    proto->mutable_uint64_val()->Swap(&copy);
  }
};
375 | |
376 | template <> |
377 | struct ProtoHelper<ResourceHandle> { |
378 | static protobuf::RepeatedPtrField<ResourceHandleProto>::const_iterator Begin( |
379 | const TensorProto& proto) { |
380 | return proto.resource_handle_val().begin(); |
381 | } |
382 | static size_t NumElements(const TensorProto& proto) { |
383 | return proto.resource_handle_val().size(); |
384 | } |
385 | static void Fill(const ResourceHandle* data, size_t n, TensorProto* proto) { |
386 | auto* handles = proto->mutable_resource_handle_val(); |
387 | handles->Clear(); |
388 | for (size_t i = 0; i < n; i++) { |
389 | data[i].AsProto(handles->Add()); |
390 | } |
391 | } |
392 | }; |
393 | |
394 | template <> |
395 | struct ProtoHelper<Variant> { |
396 | static protobuf::RepeatedPtrField<VariantTensorDataProto>::const_iterator |
397 | Begin(const TensorProto& proto) { |
398 | return proto.variant_val().begin(); |
399 | } |
400 | static size_t NumElements(const TensorProto& proto) { |
401 | return proto.variant_val().size(); |
402 | } |
403 | static void Fill(const Variant* data, size_t n, TensorProto* proto) { |
404 | auto* variant_values = proto->mutable_variant_val(); |
405 | variant_values->Clear(); |
406 | for (size_t i = 0; i < n; ++i) { |
407 | VariantTensorData tmp; |
408 | data[i].Encode(&tmp); |
409 | tmp.ToProto(variant_values->Add()); |
410 | } |
411 | } |
412 | }; |
413 | |
// complex64 tensors are stored in scomplex_val as interleaved float
// real/imaginary pairs, so each element occupies two repeated-field slots.
template <>
struct ProtoHelper<complex64> {
  typedef Helper<float>::RepeatedFieldType FieldType;
  static const complex64* Begin(const TensorProto& proto) {
    // Reinterpret the contiguous float pairs as complex64 values.
    return reinterpret_cast<const complex64*>(proto.scomplex_val().data());
  }
  static size_t NumElements(const TensorProto& proto) {
    // Two floats per complex element.
    return proto.scomplex_val().size() / 2;
  }
  static void Fill(const complex64* data, size_t n, TensorProto* proto) {
    const float* p = reinterpret_cast<const float*>(data);
    FieldType copy(p, p + n * 2);
    proto->mutable_scomplex_val()->Swap(&copy);
  }
};
429 | |
// complex128 tensors are stored in dcomplex_val as interleaved double
// real/imaginary pairs, so each element occupies two repeated-field slots.
template <>
struct ProtoHelper<complex128> {
  typedef Helper<double>::RepeatedFieldType FieldType;
  static const complex128* Begin(const TensorProto& proto) {
    // Reinterpret the contiguous double pairs as complex128 values.
    return reinterpret_cast<const complex128*>(proto.dcomplex_val().data());
  }
  static size_t NumElements(const TensorProto& proto) {
    // Two doubles per complex element.
    return proto.dcomplex_val().size() / 2;
  }
  static void Fill(const complex128* data, size_t n, TensorProto* proto) {
    const double* p = reinterpret_cast<const double*>(data);
    FieldType copy(p, p + n * 2);
    proto->mutable_dcomplex_val()->Swap(&copy);
  }
};
445 | |
// qint32 shares int32's storage layout and is carried in the int_val field.
template <>
struct ProtoHelper<qint32> {
  typedef Helper<int32>::RepeatedFieldType FieldType;
  static const qint32* Begin(const TensorProto& proto) {
    // qint32 wraps an int32, so the repeated int32 data can be viewed
    // directly as qint32.
    return reinterpret_cast<const qint32*>(proto.int_val().data());
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.int_val().size();
  }
  static void Fill(const qint32* data, size_t n, TensorProto* proto) {
    const int32* p = reinterpret_cast<const int32*>(data);
    FieldType copy(p, p + n);
    proto->mutable_int_val()->Swap(&copy);
  }
};
461 | |
462 | template <> |
463 | struct ProtoHelper<bfloat16> { |
464 | static void Fill(const bfloat16* data, size_t n, TensorProto* proto) { |
465 | proto->mutable_half_val()->Reserve(n); |
466 | for (size_t i = 0; i < n; ++i) { |
467 | proto->mutable_half_val()->AddAlreadyReserved( |
468 | Eigen::numext::bit_cast<uint16>(data[i])); |
469 | } |
470 | } |
471 | }; |
472 | |
473 | template <> |
474 | struct ProtoHelper<Eigen::half> { |
475 | static void Fill(const Eigen::half* data, size_t n, TensorProto* proto) { |
476 | proto->mutable_half_val()->Reserve(n); |
477 | for (size_t i = 0; i < n; ++i) { |
478 | proto->mutable_half_val()->AddAlreadyReserved( |
479 | Eigen::numext::bit_cast<uint16>(data[i])); |
480 | } |
481 | } |
482 | }; |
483 | |
// Allocates a T[n] array from "a"; on allocation failure the base data
// pointer is null (callers check base<T>() for nullptr).
template <typename T>
Buffer<T>::Buffer(Allocator* a, int64_t n)
    : BufferBase(a, TypedAllocator::Allocate<T>(a, n, AllocationAttributes())),
      elem_(n) {}

// Same as above, but forwards caller-provided allocation attributes.
template <typename T>
Buffer<T>::Buffer(Allocator* a, int64_t n,
                  const AllocationAttributes& allocation_attr)
    : BufferBase(a, TypedAllocator::Allocate<T>(a, n, allocation_attr)),
      elem_(n) {}

// Deallocates the T[elem_] array if allocation succeeded, logging the
// deallocation first when memory logging is enabled.
template <typename T>
Buffer<T>::~Buffer() {
  if (data()) {
    if (MemoryLoggingEnabled()) {
      RecordDeallocation();
    }
    TypedAllocator::Deallocate<T>(alloc_, static_cast<T*>(data()), elem_);
  }
}
504 | |
// Allocates a T[n] buffer. Fills in the buffer with repeated values
// in "in". If "in" has less values than "n", fills the rest of T[n]
// with the last value. If "in" has no values, fills T[n] with the
// default value for T.
//
// This routine is using the typed fields (float_val, etc.) in the
// tensor proto as opposed to the untyped binary representation
// (tensor_content). This is used when we expect the TensorProto is
// used by a client program which may not know how to encode a tensor
// in the compact binary representation.
//
// Returns nullptr if the buffer allocation fails.
template <typename T>
TensorBuffer* FromProtoField(Allocator* a, const TensorProto& in, int64_t n) {
  CHECK_GT(n, 0);
  Buffer<T>* buf = new Buffer<T>(a, n);
  T* data = buf->template base<T>();
  if (data == nullptr) {
    buf->Unref();
    return nullptr;
  }

  const int64_t in_n = ProtoHelper<T>::NumElements(in);
  if (in_n <= 0) {
    // No serialized values: default-initialize the whole buffer.
    std::fill_n(data, n, T());
  } else {
    auto begin = ProtoHelper<T>::Begin(in);
    if (n <= in_n) {
      std::copy_n(begin, n, data);
    } else {
      std::copy_n(begin, in_n, data);
      if (std::is_trivially_copyable<T>::value) {
        // For trivial types, take the pad value by copy: cheap, and the
        // fill never reads through a reference into the destination.
        const T last = *(data + in_n - 1);
        std::fill_n(data + in_n, n - in_n, last);
      } else {
        // For heavier types (e.g. tstring), fill from a reference to avoid
        // an extra copy. The referenced element sits just before the fill
        // range [data + in_n, data + n), so it is never overwritten.
        const T& last = *(data + in_n - 1);
        std::fill_n(data + in_n, n - in_n, last);
      }
    }
  }

  return buf;
}
546 | |
// Separate implementation for `ResourceHandle` to handle the case when the
// proto for the resource is invalid. See `resource_handle.h` constructor and
// static factory builder.
//
// Unlike the generic FromProtoField, a malformed element proto causes the
// whole decode to fail (returns nullptr) rather than CHECK-failing, and
// missing trailing elements are default-initialized rather than replicated
// from the last value.
template <>
TensorBuffer* FromProtoField<ResourceHandle>(Allocator* a,
                                             const TensorProto& in, int64_t n) {
  CHECK_GT(n, 0);
  Buffer<ResourceHandle>* buf = new Buffer<ResourceHandle>(a, n);
  ResourceHandle* data = buf->template base<ResourceHandle>();
  if (data == nullptr) {
    buf->Unref();
    return nullptr;
  }
  const int64_t in_n = ProtoHelper<ResourceHandle>::NumElements(in);
  if (in_n <= 0) {
    std::fill_n(data, n, ResourceHandle());
  } else {
    // If tensor shape says we have n < in_n elements in the output tensor
    // then make sure to only decode the first n out of the in_n elements in the
    // in tensors. In all other cases, we decode all in_n elements of in and set
    // the remaining elements up to n to be the default ResourceHandle() value.
    const int64_t real_n = n < in_n ? n : in_n;
    for (int64_t i = 0; i < real_n; ++i) {
      Status s = ResourceHandle::BuildResourceHandle(in.resource_handle_val(i),
                                                     &data[i]);
      if (!s.ok()) {
        LOG(ERROR) << "Could not decode resource handle from proto \""
                   << in.resource_handle_val(i).ShortDebugString()
                   << "\", returned status: " << s.ToString();
        buf->Unref();
        return nullptr;
      }
    }
    // This loop is a no-op when in_n >= n.
    for (int64_t i = in_n; i < n; ++i) {
      data[i] = ResourceHandle();
    }
  }
  return buf;
}
586 | |
// Variant-specific decode: each element is copied from variant_val and then
// run through the registered unary-variant decoder. A failed decode fails
// the whole tensor (returns nullptr); missing trailing elements are
// default-initialized.
template <>
TensorBuffer* FromProtoField<Variant>(Allocator* a, const TensorProto& in,
                                      int64_t n) {
  CHECK_GT(n, 0);
  Buffer<Variant>* buf = new Buffer<Variant>(a, n);
  Variant* data = buf->template base<Variant>();
  if (data == nullptr) {
    buf->Unref();
    return nullptr;
  }
  const int64_t in_n = ProtoHelper<Variant>::NumElements(in);
  if (in_n <= 0) {
    std::fill_n(data, n, Variant());
  } else {
    // If tensor shape says we have n < in_n elements in the output tensor
    // then make sure to only decode the first n out of the in_n elements in the
    // in tensors. In all other cases, we decode all in_n elements of in and set
    // the remaining elements up to n to be the default Variant() value.
    const int64_t real_n = n < in_n ? n : in_n;
    for (int64_t i = 0; i < real_n; ++i) {
      data[i] = in.variant_val(i);
      if (!DecodeUnaryVariant(&data[i])) {
        LOG(ERROR) << "Could not decode variant with type_name: \""
                   << data[i].TypeName()
                   << "\". Perhaps you forgot to register a "
                      "decoder via REGISTER_UNARY_VARIANT_DECODE_FUNCTION?" ;
        buf->Unref();
        return nullptr;
      }
    }
    // This loop is a no-op when in_n >= n.
    for (int64_t i = in_n; i < n; ++i) {
      data[i] = Variant();
    }
  }
  return buf;
}
623 | |
624 | // fp16 and bfloat16 are opaque to the protobuf, so we deserialize these |
625 | // identical to uint16 but with data stored in half_val instead of int_val (ie., |
626 | // we don't use ProtoHelper<uint16>). |
627 | template <> |
628 | TensorBuffer* FromProtoField<Eigen::half>(Allocator* a, const TensorProto& in, |
629 | int64_t n) { |
630 | CHECK_GT(n, 0); |
631 | Buffer<Eigen::half>* buf = new Buffer<Eigen::half>(a, n); |
632 | uint16* data = buf->template base<uint16>(); |
633 | if (data == nullptr) { |
634 | buf->Unref(); |
635 | return nullptr; |
636 | } |
637 | const int64_t in_n = in.half_val().size(); |
638 | auto begin = in.half_val().begin(); |
639 | if (n <= in_n) { |
640 | std::copy_n(begin, n, data); |
641 | } else if (in_n > 0) { |
642 | std::copy_n(begin, in_n, data); |
643 | const uint16 last = *(data + in_n - 1); |
644 | std::fill_n(data + in_n, n - in_n, last); |
645 | } else { |
646 | std::fill_n(data, n, 0); |
647 | } |
648 | return buf; |
649 | } |
650 | |
651 | template <> |
652 | TensorBuffer* FromProtoField<bfloat16>(Allocator* a, const TensorProto& in, |
653 | int64_t n) { |
654 | CHECK_GT(n, 0); |
655 | Buffer<bfloat16>* buf = new Buffer<bfloat16>(a, n); |
656 | uint16* data = buf->template base<uint16>(); |
657 | if (data == nullptr) { |
658 | buf->Unref(); |
659 | return nullptr; |
660 | } |
661 | const int64_t in_n = in.half_val().size(); |
662 | auto begin = in.half_val().begin(); |
663 | if (n <= in_n) { |
664 | std::copy_n(begin, n, data); |
665 | } else if (in_n > 0) { |
666 | std::copy_n(begin, in_n, data); |
667 | const uint16 last = *(data + in_n - 1); |
668 | std::fill_n(data + in_n, n - in_n, last); |
669 | } else { |
670 | std::fill_n(data, n, 0); |
671 | } |
672 | return buf; |
673 | } |
674 | |
// Copies T[n] stored in the buffer "in" into the repeated field in
// "out" corresponding to type T.
template <typename T>
void ToProtoField(const TensorBuffer& in, int64_t n, TensorProto* out) {
  const T* data = in.base<const T>();
  // NOTE: T may not be the same as
  // ProtoHelper<T>::FieldType::value_type. E.g., T==int16,
  // ProtoHelper<T>::FieldType::value_type==int32. If performance is
  // critical, we can specialize T=float and do memcpy directly.
  ProtoHelper<T>::Fill(data, n, out);
}
686 | |
687 | void RefIfNonNull(core::RefCounted* buf) { |
688 | if (buf) buf->Ref(); |
689 | } |
690 | |
691 | void UnrefIfNonNull(core::RefCounted* buf) { |
692 | if (buf) buf->Unref(); |
693 | } |
694 | |
695 | } // end namespace |
696 | |
// Default-constructed tensors are DT_FLOAT with no backing buffer.
Tensor::Tensor() : Tensor(DT_FLOAT) {}

// Creates a buffer-less tensor of the given type.
Tensor::Tensor(DataType type) : shape_(type), buf_(nullptr) {}

// Wraps an existing buffer, sharing ownership by taking a new reference.
Tensor::Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf)
    : shape_(shape), buf_(buf) {
  set_dtype(type);
  RefIfNonNull(buf);
}

// Adopts the reference held by "buf" (no refcount change: release()).
Tensor::Tensor(DataType type, TensorShape shape,
               core::RefCountPtr<TensorBuffer> buf)
    : shape_(std::move(shape)), buf_(buf.release()) {
  set_dtype(type);
}
712 | |
713 | bool Tensor::IsInitialized() const { |
714 | return (buf_ != nullptr && buf_->data() != nullptr) || |
715 | shape_.num_elements() == 0; |
716 | } |
717 | |
// CHECK-fails unless this tensor's dtype matches "expected_dtype".
void Tensor::CheckType(DataType expected_dtype) const {
  CHECK_EQ(dtype(), expected_dtype)
      << " " << DataTypeString(expected_dtype) << " expected, got "
      << DataTypeString(dtype());
}

// CHECK-fails unless the dtype matches AND the base pointer is suitably
// aligned for typed access.
void Tensor::CheckTypeAndIsAligned(DataType expected_dtype) const {
  CHECK_EQ(dtype(), expected_dtype)
      << " " << DataTypeString(expected_dtype) << " expected, got "
      << DataTypeString(dtype());
  CHECK(IsAligned()) << "ptr = " << base<void>();
}

// CHECK-fails unless the buffer is aligned and holds exactly one element.
// NOTE(review): the message "Aligned and single element" is printed when the
// tensor is NOT aligned, which reads oddly in failure logs.
void Tensor::CheckIsAlignedAndSingleElement() const {
  CHECK(IsAligned()) << "Aligned and single element" ;
  CHECK_EQ(1, NumElements()) << "Must have a one element tensor" ;
}

// Releases this tensor's reference on the underlying buffer, if any.
Tensor::~Tensor() { UnrefIfNonNull(buf_); }
737 | |
// Reinterprets "other"'s storage as a tensor of "dtype"/"shape" without
// copying: on success this tensor shares (takes a reference on) other's
// buffer. Both dtypes must have a nonzero DataTypeSize and the total byte
// counts implied by shape * element size must match exactly.
Status Tensor::BitcastFrom(const Tensor& other, DataType dtype,
                           const TensorShape& shape) {
  int in_size = DataTypeSize(other.dtype());
  int out_size = DataTypeSize(dtype);
  // A DataTypeSize of 0 marks a type without a fixed per-element byte size,
  // which cannot be bitcast.
  if (in_size == 0) {
    return errors::InvalidArgument("other tensor has zero-sized data type" );
  }
  if (out_size == 0) {
    return errors::InvalidArgument("specified output type is zero-sized" );
  }
  if (shape.num_elements() * out_size !=
      other.shape().num_elements() * in_size) {
    return errors::InvalidArgument(
        "input and output shapes/data type sizes are not compatible" );
  }
  shape_ = shape;
  shape_.set_data_type(dtype);
  // Swap buffers only when they differ, keeping refcounts balanced.
  if (buf_ != other.buf_) {
    UnrefIfNonNull(buf_);
    buf_ = other.buf_;
    RefIfNonNull(buf_);
  }
  return OkStatus();
}
762 | |
// Notice that buf_ either points to a regular TensorBuffer or a SubBuffer.
// For the latter case, we have to make sure that the refcount is
// one both for the SubBuffer _and_ the underlying TensorBuffer.
bool Tensor::RefCountIsOne() const {
  // Also requires the buffer to own its memory: a buffer wrapping foreign
  // memory is never considered exclusively held.
  return buf_ != nullptr && buf_->RefCountIsOne() &&
         buf_->root_buffer()->RefCountIsOne() && buf_->OwnsMemory();
}
770 | |
771 | // The macro CASES() expands to a switch statement conditioned on |
772 | // TYPE_ENUM. Each case expands the STMTS after a typedef for T. |
773 | #define SINGLE_ARG(...) __VA_ARGS__ |
774 | #define CASE(TYPE, STMTS) \ |
775 | case DataTypeToEnum<TYPE>::value: { \ |
776 | typedef TF_ATTRIBUTE_UNUSED TYPE |
---|