1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/variant_tensor_data.h" |
17 | #include "tensorflow/core/framework/tensor.h" |
18 | #include "tensorflow/core/framework/tensor.pb.h" |
19 | #include "tensorflow/core/lib/strings/strcat.h" |
20 | |
21 | namespace tensorflow { |
22 | |
23 | VariantTensorData::VariantTensorData(VariantTensorDataProto proto) { |
24 | FromProto(std::move(proto)); |
25 | } |
26 | |
27 | int VariantTensorData::tensors_size() const { return tensors_.size(); } |
28 | |
29 | const Tensor& VariantTensorData::tensors(int index) const { |
30 | return tensors_[index]; |
31 | } |
32 | |
33 | const std::vector<Tensor>& VariantTensorData::tensors() const { |
34 | return tensors_; |
35 | } |
36 | |
37 | Tensor* VariantTensorData::add_tensors() { |
38 | tensors_.emplace_back(); |
39 | return &(tensors_[tensors_.size() - 1]); |
40 | } |
41 | |
42 | void VariantTensorData::ToProto(VariantTensorDataProto* proto) const { |
43 | proto->set_type_name(type_name()); |
44 | proto->set_metadata(metadata_); |
45 | proto->clear_tensors(); |
46 | for (const auto& tensor : tensors_) { |
47 | tensor.AsProtoField(proto->mutable_tensors()->Add()); |
48 | } |
49 | } |
50 | |
51 | bool VariantTensorData::FromProto(VariantTensorDataProto proto) { |
52 | // TODO(ebrevdo): Do this lazily. |
53 | set_type_name(proto.type_name()); |
54 | set_metadata(proto.metadata()); |
55 | for (const auto& tensor : proto.tensors()) { |
56 | Tensor tmp; |
57 | if (!tmp.FromProto(tensor)) return false; |
58 | tensors_.push_back(tmp); |
59 | } |
60 | return true; |
61 | } |
62 | |
63 | bool VariantTensorData::FromConstProto(const VariantTensorDataProto& proto) { |
64 | set_type_name(proto.type_name()); |
65 | set_metadata(proto.metadata()); |
66 | for (const auto& tensor : proto.tensors()) { |
67 | Tensor tmp; |
68 | if (!tmp.FromProto(tensor)) return false; |
69 | tensors_.push_back(tmp); |
70 | } |
71 | return true; |
72 | } |
73 | |
74 | string VariantTensorData::SerializeAsString() const { |
75 | VariantTensorDataProto proto; |
76 | ToProto(&proto); |
77 | return proto.SerializeAsString(); |
78 | } |
79 | |
80 | bool VariantTensorData::SerializeToString(string* buf) { |
81 | VariantTensorDataProto proto; |
82 | ToProto(&proto); |
83 | return proto.SerializeToString(buf); |
84 | } |
85 | |
86 | bool VariantTensorData::ParseFromString(string s) { |
87 | VariantTensorDataProto proto; |
88 | const bool status = proto.ParseFromString(s); |
89 | if (status) FromProto(std::move(proto)); |
90 | return status; |
91 | } |
92 | |
93 | string VariantTensorData::DebugString() const { |
94 | string repeated_field = "" ; |
95 | for (const auto& t : tensors_) { |
96 | repeated_field = |
97 | strings::StrCat(repeated_field, " tensors: " , t.DebugString()); |
98 | } |
99 | return strings::StrCat("type_name: " , type_name(), " metadata: " , metadata_, |
100 | repeated_field); |
101 | } |
102 | |
103 | string ProtoDebugString(const VariantTensorData& object) { |
104 | return object.DebugString(); |
105 | } |
106 | |
107 | } // namespace tensorflow |
108 | |