1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_FRAMEWORK_VARIANT_TENSOR_DATA_H_ |
17 | #define TENSORFLOW_CORE_FRAMEWORK_VARIANT_TENSOR_DATA_H_ |
18 | |
#include <algorithm>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/types.h"
25 | |
26 | namespace tensorflow { |
27 | |
// Forward declaration: in this header VariantTensorDataProto is only named in
// member function declarations, so its complete definition is not required
// here.
class VariantTensorDataProto;
29 | |
30 | // The serialization format for Variant objects. Objects with references to |
31 | // other Tensors can simply store those tensors in the `tensors` field, and |
32 | // serialize other metadata content in to the `metadata` field. Objects can |
33 | // optionally set the `type_name` for type-checking before deserializing an |
34 | // object. |
35 | // |
36 | // This is the native C++ class equivalent of VariantTensorDataProto. They are |
37 | // separate so that kernels do not need to depend on protos. |
38 | class VariantTensorData { |
39 | public: |
40 | VariantTensorData() = default; |
41 | |
42 | // TODO(b/118823936): This silently returns if the proto is invalid. |
43 | // Consider calling FromProto explicitly instead. |
44 | VariantTensorData(VariantTensorDataProto proto); |
45 | |
46 | // Name of the type of objects being serialized. |
47 | const std::string& type_name() const { return type_name_; } |
48 | void set_type_name(const std::string& type_name) { type_name_ = type_name; } |
49 | |
50 | template <typename T, bool = std::is_pod<typename std::decay<T>::type>::value> |
51 | struct PODResolver {}; |
52 | |
53 | // Portions of the object that are not Tensors. |
54 | // Directly supported types include string POD types. |
55 | template <typename T> |
56 | void set_metadata(const T& value) { |
57 | SetMetadata<T>(value, PODResolver<T>()); |
58 | } |
59 | |
60 | template <typename T> |
61 | bool get_metadata(T* value) const { |
62 | return GetMetadata<T>(value, PODResolver<T>()); |
63 | } |
64 | |
65 | std::string& metadata_string() { return metadata_; } |
66 | |
67 | const std::string& metadata_string() const { return metadata_; } |
68 | |
69 | // Tensors contained within objects being serialized. |
70 | int tensors_size() const; |
71 | const Tensor& tensors(int index) const; |
72 | const std::vector<Tensor>& tensors() const; |
73 | Tensor* add_tensors(); |
74 | |
75 | // A more general version of add_tensors. Parameters are perfectly forwarded |
76 | // to the constructor of the tensor added here. |
77 | template <typename... TensorConstructorArgs> |
78 | Tensor* add_tensor(TensorConstructorArgs&&... args); |
79 | |
80 | // Conversion to and from VariantTensorDataProto |
81 | void ToProto(VariantTensorDataProto* proto) const; |
82 | // This allows optimizations via std::move. |
83 | bool FromProto(VariantTensorDataProto proto); |
84 | bool FromConstProto(const VariantTensorDataProto& proto); |
85 | |
86 | // Serialization via VariantTensorDataProto |
87 | std::string SerializeAsString() const; |
88 | bool SerializeToString(std::string* buf); |
89 | bool ParseFromString(std::string s); |
90 | |
91 | std::string DebugString() const; |
92 | |
93 | public: |
94 | std::string type_name_; |
95 | std::string metadata_; |
96 | std::vector<Tensor> tensors_; |
97 | |
98 | private: |
99 | template <typename T> |
100 | void SetMetadata(const std::string& value, |
101 | PODResolver<T, false /* is_pod */>) { |
102 | metadata_ = value; |
103 | } |
104 | |
105 | template <typename T> |
106 | bool GetMetadata(std::string* value, |
107 | PODResolver<T, false /* is_pod */>) const { |
108 | *value = metadata_; |
109 | return true; |
110 | } |
111 | |
112 | template <typename T> |
113 | void SetMetadata(const T& value, PODResolver<T, true /* is_pod */>) { |
114 | metadata_.assign(reinterpret_cast<const char*>(&value), sizeof(T)); |
115 | } |
116 | |
117 | template <typename T> |
118 | bool GetMetadata(T* value, PODResolver<T, true /* is_pod */>) const { |
119 | if (metadata_.size() != sizeof(T)) return false; |
120 | std::copy_n(metadata_.data(), sizeof(T), reinterpret_cast<char*>(value)); |
121 | return true; |
122 | } |
123 | }; |
124 | |
125 | // For backwards compatibility for when this was a proto |
126 | std::string ProtoDebugString(const VariantTensorData& object); |
127 | |
128 | template <typename... TensorConstructorArgs> |
129 | Tensor* VariantTensorData::add_tensor(TensorConstructorArgs&&... args) { |
130 | tensors_.emplace_back(std::forward<TensorConstructorArgs>(args)...); |
131 | return &tensors_.back(); |
132 | } |
133 | |
134 | } // namespace tensorflow |
135 | |
136 | #endif // TENSORFLOW_CORE_FRAMEWORK_VARIANT_TENSOR_DATA_H_ |
137 | |