1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/core/kernels/tensor_list.h" |
16 | |
17 | #include "tensorflow/core/framework/tensor_shape.h" |
18 | #include "tensorflow/core/framework/tensor_shape.pb.h" |
19 | #include "tensorflow/core/framework/variant_op_registry.h" |
20 | #include "tensorflow/core/lib/core/coding.h" |
21 | |
22 | namespace tensorflow { |
23 | |
24 | TensorList::~TensorList() { |
25 | if (tensors_) tensors_->Unref(); |
26 | } |
27 | |
28 | void TensorList::Encode(VariantTensorData* data) const { |
29 | data->set_type_name(TypeName()); |
30 | std::vector<size_t> invalid_indices; |
31 | for (size_t i = 0; i < tensors().size(); i++) { |
32 | if (tensors().at(i).dtype() != DT_INVALID) { |
33 | *data->add_tensors() = tensors().at(i); |
34 | } else { |
35 | invalid_indices.push_back(i); |
36 | } |
37 | } |
38 | string metadata; |
39 | // TODO(b/118838800): Add a proto for storing the metadata. |
40 | // Metadata format: |
41 | // <num_invalid_tensors><invalid_indices><element_dtype><element_shape_proto> |
42 | core::PutVarint64(&metadata, static_cast<uint64>(invalid_indices.size())); |
43 | for (size_t i : invalid_indices) { |
44 | core::PutVarint64(&metadata, static_cast<uint64>(i)); |
45 | } |
46 | core::PutVarint64(&metadata, static_cast<uint64>(element_dtype)); |
47 | core::PutVarint64(&metadata, static_cast<uint64>(max_num_elements)); |
48 | TensorShapeProto element_shape_proto; |
49 | element_shape.AsProto(&element_shape_proto); |
50 | element_shape_proto.AppendToString(&metadata); |
51 | data->set_metadata(metadata); |
52 | } |
53 | |
54 | bool TensorList::Decode(const VariantTensorData& data) { |
55 | // TODO(srbs): Change the signature to Decode(VariantTensorData data) so |
56 | // that we do not have to copy each tensor individually below. This would |
57 | // require changing VariantTensorData::tensors() as well. |
58 | string metadata; |
59 | data.get_metadata(&metadata); |
60 | uint64 scratch; |
61 | StringPiece iter(metadata); |
62 | std::vector<size_t> invalid_indices; |
63 | core::GetVarint64(&iter, &scratch); |
64 | size_t num_invalid_tensors = static_cast<size_t>(scratch); |
65 | invalid_indices.resize(num_invalid_tensors); |
66 | for (size_t i = 0; i < num_invalid_tensors; i++) { |
67 | core::GetVarint64(&iter, &scratch); |
68 | invalid_indices[i] = static_cast<size_t>(scratch); |
69 | } |
70 | |
71 | size_t total_num_tensors = data.tensors().size() + num_invalid_tensors; |
72 | tensors().reserve(total_num_tensors); |
73 | std::vector<size_t>::iterator invalid_indices_it = invalid_indices.begin(); |
74 | std::vector<Tensor>::const_iterator tensors_it = data.tensors().begin(); |
75 | for (size_t i = 0; i < total_num_tensors; i++) { |
76 | if (invalid_indices_it != invalid_indices.end() && |
77 | *invalid_indices_it == i) { |
78 | tensors().emplace_back(Tensor(DT_INVALID)); |
79 | invalid_indices_it++; |
80 | } else if (tensors_it != data.tensors().end()) { |
81 | tensors().emplace_back(*tensors_it); |
82 | tensors_it++; |
83 | } else { |
84 | // VariantTensorData is corrupted. |
85 | return false; |
86 | } |
87 | } |
88 | |
89 | core::GetVarint64(&iter, &scratch); |
90 | element_dtype = static_cast<DataType>(scratch); |
91 | core::GetVarint64(&iter, &scratch); |
92 | max_num_elements = static_cast<int>(scratch); |
93 | TensorShapeProto element_shape_proto; |
94 | element_shape_proto.ParseFromString(string(iter.data(), iter.size())); |
95 | element_shape = PartialTensorShape(element_shape_proto); |
96 | return true; |
97 | } |
98 | |
99 | const char TensorList::kTypeName[] = "tensorflow::TensorList" ; |
100 | |
101 | } // namespace tensorflow |
102 | |