1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/core/kernels/tensor_list.h"
16
17#include "tensorflow/core/framework/tensor_shape.h"
18#include "tensorflow/core/framework/tensor_shape.pb.h"
19#include "tensorflow/core/framework/variant_op_registry.h"
20#include "tensorflow/core/lib/core/coding.h"
21
22namespace tensorflow {
23
24TensorList::~TensorList() {
25 if (tensors_) tensors_->Unref();
26}
27
28void TensorList::Encode(VariantTensorData* data) const {
29 data->set_type_name(TypeName());
30 std::vector<size_t> invalid_indices;
31 for (size_t i = 0; i < tensors().size(); i++) {
32 if (tensors().at(i).dtype() != DT_INVALID) {
33 *data->add_tensors() = tensors().at(i);
34 } else {
35 invalid_indices.push_back(i);
36 }
37 }
38 string metadata;
39 // TODO(b/118838800): Add a proto for storing the metadata.
40 // Metadata format:
41 // <num_invalid_tensors><invalid_indices><element_dtype><element_shape_proto>
42 core::PutVarint64(&metadata, static_cast<uint64>(invalid_indices.size()));
43 for (size_t i : invalid_indices) {
44 core::PutVarint64(&metadata, static_cast<uint64>(i));
45 }
46 core::PutVarint64(&metadata, static_cast<uint64>(element_dtype));
47 core::PutVarint64(&metadata, static_cast<uint64>(max_num_elements));
48 TensorShapeProto element_shape_proto;
49 element_shape.AsProto(&element_shape_proto);
50 element_shape_proto.AppendToString(&metadata);
51 data->set_metadata(metadata);
52}
53
54bool TensorList::Decode(const VariantTensorData& data) {
55 // TODO(srbs): Change the signature to Decode(VariantTensorData data) so
56 // that we do not have to copy each tensor individually below. This would
57 // require changing VariantTensorData::tensors() as well.
58 string metadata;
59 data.get_metadata(&metadata);
60 uint64 scratch;
61 StringPiece iter(metadata);
62 std::vector<size_t> invalid_indices;
63 core::GetVarint64(&iter, &scratch);
64 size_t num_invalid_tensors = static_cast<size_t>(scratch);
65 invalid_indices.resize(num_invalid_tensors);
66 for (size_t i = 0; i < num_invalid_tensors; i++) {
67 core::GetVarint64(&iter, &scratch);
68 invalid_indices[i] = static_cast<size_t>(scratch);
69 }
70
71 size_t total_num_tensors = data.tensors().size() + num_invalid_tensors;
72 tensors().reserve(total_num_tensors);
73 std::vector<size_t>::iterator invalid_indices_it = invalid_indices.begin();
74 std::vector<Tensor>::const_iterator tensors_it = data.tensors().begin();
75 for (size_t i = 0; i < total_num_tensors; i++) {
76 if (invalid_indices_it != invalid_indices.end() &&
77 *invalid_indices_it == i) {
78 tensors().emplace_back(Tensor(DT_INVALID));
79 invalid_indices_it++;
80 } else if (tensors_it != data.tensors().end()) {
81 tensors().emplace_back(*tensors_it);
82 tensors_it++;
83 } else {
84 // VariantTensorData is corrupted.
85 return false;
86 }
87 }
88
89 core::GetVarint64(&iter, &scratch);
90 element_dtype = static_cast<DataType>(scratch);
91 core::GetVarint64(&iter, &scratch);
92 max_num_elements = static_cast<int>(scratch);
93 TensorShapeProto element_shape_proto;
94 element_shape_proto.ParseFromString(string(iter.data(), iter.size()));
95 element_shape = PartialTensorShape(element_shape_proto);
96 return true;
97}
98
// Out-of-line definition of TensorList's static type-name string.
// NOTE(review): presumably this is what TypeName() returns, and hence the name
// Encode() stores via set_type_name(); confirm in tensor_list.h.
const char TensorList::kTypeName[] = "tensorflow::TensorList";
100
101} // namespace tensorflow
102