1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/kernels/composite_tensor_variant.h" |
17 | |
18 | #include "tensorflow/core/framework/variant_op_registry.h" |
19 | #include "tensorflow/core/platform/errors.h" |
20 | #include "tensorflow/core/protobuf/composite_tensor_variant.pb.h" |
21 | #include "tensorflow/core/protobuf/struct.pb.h" |
22 | |
23 | namespace tensorflow { |
24 | |
// Out-of-line definition of the static constexpr member `kTypeName`, which is
// presumably declared (with its initializer) in composite_tensor_variant.h.
// The header is not visible here; this form is what pre-C++17 code needs when
// the member is ODR-used. The registration macro at the bottom of this file
// references it, for example.
constexpr const char CompositeTensorVariant::kTypeName[];
26 | |
27 | CompositeTensorVariant::CompositeTensorVariant( |
28 | const CompositeTensorVariantMetadata& metadata, |
29 | absl::Span<Tensor> flat_components) |
30 | : flat_components_(flat_components.begin(), flat_components.end()), |
31 | metadata_(new CompositeTensorVariantMetadata()) { |
32 | *metadata_ = metadata; |
33 | } |
34 | |
35 | CompositeTensorVariant::CompositeTensorVariant() |
36 | : metadata_(new CompositeTensorVariantMetadata()) {} |
37 | |
38 | CompositeTensorVariant::CompositeTensorVariant( |
39 | const CompositeTensorVariant& other) |
40 | : flat_components_(other.flat_components_), |
41 | metadata_(new CompositeTensorVariantMetadata()) { |
42 | *metadata_ = *other.metadata_; |
43 | } |
44 | |
45 | void CompositeTensorVariant::Encode(VariantTensorData* data) const { |
46 | data->set_type_name(TypeName()); |
47 | metadata_->SerializeToString(&data->metadata_string()); |
48 | for (const Tensor& tensor : flat_components_) { |
49 | data->add_tensor(tensor); |
50 | } |
51 | } |
52 | |
53 | bool CompositeTensorVariant::Decode(const VariantTensorData& data) { |
54 | if (!metadata_->ParseFromString(data.metadata_string())) { |
55 | return false; |
56 | } |
57 | flat_components_ = data.tensors(); |
58 | return true; |
59 | } |
60 | |
61 | string CompositeTensorVariant::DebugString() const { |
62 | string result("<CompositeTensorVariant type=" ); |
63 | result.append(TypeSpecProto::TypeSpecClass_Name( |
64 | metadata_->type_spec_proto().type_spec_class())); |
65 | result.append(", components=[" ); |
66 | for (const auto& tensor : flat_components_) { |
67 | if (&tensor != &flat_components_[0]) { |
68 | result.append(", " ); |
69 | } |
70 | result.append(tensor.DebugString()); |
71 | } |
72 | result.append("]>" ); |
73 | return result; |
74 | } |
75 | |
// Registers CompositeTensorVariant with the unary variant registry (declared
// in variant_op_registry.h), keyed by CompositeTensorVariant::kTypeName.
// Presumably this lets serialized VariantTensorData whose type name matches
// be decoded back into a CompositeTensorVariant via Decode(). See the macro
// definition to confirm the exact mechanism.
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(CompositeTensorVariant,
                                       CompositeTensorVariant::kTypeName);
78 | |
79 | } // namespace tensorflow |
80 | |